diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 42201afc86..f5a38c2670 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError
@@ -101,7 +101,9 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
+ def on_PUT(
+ self, request: SynapseRequest, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 0443f4571c..a0971ce994 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -16,7 +16,7 @@
"""
import logging
import re
-from typing import Iterable, Pattern
+from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
)
-def interactive_auth_handler(orig):
+C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])
+
+
+def interactive_auth_handler(orig: C) -> C:
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns an Awaitable (errcode, body) response
@@ -91,10 +94,10 @@ def interactive_auth_handler(orig):
await self.auth_handler.check_auth
"""
- async def wrapped(*args, **kwargs):
+ async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result
- return wrapped
+ return cast(C, wrapped)
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index fb5ad2906e..aefaaa8ae8 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -16,9 +16,11 @@
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@@ -28,15 +30,17 @@ from synapse.api.errors import (
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed, validate_email
@@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=self.config.email_password_reset_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet):
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these
@@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
+ requester = None
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
try:
@@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
- password_hash = await self.auth_handler.hash(new_password)
+ new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
- password_hash,
+ new_password_hash,
)
raise
user_id = requester.user.to_string()
else:
- requester = None
try:
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
@@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
- password_hash = await self.auth_handler.hash(new_password)
+ new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
- password_hash,
+ new_password_hash,
)
raise
@@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet):
# If we have a password in this request, prefer it. Otherwise, use the
# password hash from an earlier request.
if new_password:
- password_hash = await self.auth_handler.hash(new_password)
+ password_hash: Optional[str] = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
@@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.config = hs.config
@@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
template_text=self.config.email_add_threepid_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
@@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
"/add_threepid/email/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
@@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config.email_add_threepid_template_failure_html
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> None:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
if not self.config.account_threepid_delegate_msisdn:
raise SynapseError(
400,
@@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet):
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet):
class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
@@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
-def assert_valid_next_link(hs: "HomeServer", next_link: str):
+def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
"""
Raises a SynapseError if a given next_link value is invalid
@@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
response = {"user_id": requester.user.to_string()}
@@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet):
return 200, response
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 7517e9304e..d1badbdf3b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, account_data_type):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
return 200, {}
- async def on_GET(self, request, user_id, account_data_type):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
"/account_data/(?P<account_data_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, account_data_type):
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
return 200, {}
- async def on_GET(self, request, user_id, room_id, account_data_type):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
return 200, event
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index c3667ff8aa..a7e9aa3e9b 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -15,7 +15,7 @@
import logging
from functools import wraps
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request
@@ -43,14 +43,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _validate_group_id(f):
+def _validate_group_id(
+ f: Callable[..., Awaitable[Tuple[int, JsonDict]]]
+) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]:
"""Wrapper to validate the form of the group ID.
Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
"""
@wraps(f)
- def wrapper(self, request: Request, group_id: str, *args, **kwargs):
+ def wrapper(
+ self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any
+ ) -> Awaitable[Tuple[int, JsonDict]]:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@@ -156,7 +160,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id: str,
category_id: Optional[str],
room_id: str,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -188,7 +192,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -451,7 +455,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -674,7 +678,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -706,7 +710,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: SynapseRequest, group_id, user_id
+ self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -738,7 +742,7 @@ class GroupAdminUsersKickServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: SynapseRequest, group_id, user_id
+ self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 68fb08d0ba..0152a0c66a 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from twisted.web.server import Request
@@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id}
- def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
+ def on_PUT(
+ self, request: Request, room_identifier: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 4be502a77b..bcba106bdd 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -79,7 +79,6 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
- self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._msc2918_enabled = hs.config.access_token_lifetime is not None
self.auth = hs.get_auth()
@@ -111,7 +110,7 @@ class LoginRestServlet(RestServlet):
_load_sso_handlers(hs)
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- flows = []
+ flows: List[JsonDict] = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
@@ -122,25 +121,15 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- sso_flow: JsonDict = {
- "type": LoginRestServlet.SSO_TYPE,
- "identity_providers": [
- _get_auth_flow_dict_for_idp(
- idp,
- )
- for idp in self._sso_handler.get_identity_providers().values()
- ],
- }
-
- if self._msc2858_enabled:
- # backwards-compatibility support for clients which don't
- # support the stable API yet
- sso_flow["org.matrix.msc2858.identity_providers"] = [
- _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
- for idp in self._sso_handler.get_identity_providers().values()
- ]
-
- flows.append(sso_flow)
+ flows.append(
+ {
+ "type": LoginRestServlet.SSO_TYPE,
+ "identity_providers": [
+ _get_auth_flow_dict_for_idp(idp)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ],
+ }
+ )
# While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
@@ -433,9 +422,7 @@ class LoginRestServlet(RestServlet):
return result
-def _get_auth_flow_dict_for_idp(
- idp: SsoIdentityProvider, use_unstable_brands: bool = False
-) -> JsonDict:
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
"""Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the
@@ -443,17 +430,12 @@ def _get_auth_flow_dict_for_idp(
Args:
idp: the identity provider to describe
- use_unstable_brands: whether we should use brand identifiers suitable
- for the unstable API
"""
e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
e["brand"] = idp.idp_brand
- # use the stable brand identifier if the unstable identifier isn't defined.
- if use_unstable_brands and idp.unstable_idp_brand:
- e["brand"] = idp.unstable_idp_brand
return e
@@ -504,25 +486,8 @@ class SsoRedirectServlet(RestServlet):
# register themselves with the main SSOHandler.
_load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
- self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
- def register(self, http_server: HttpServer) -> None:
- super().register(http_server)
- if self._msc2858_enabled:
- # expose additional endpoint for MSC2858 support: backwards-compat support
- # for clients which don't yet support the stable endpoints.
- http_server.register_paths(
- "GET",
- client_patterns(
- "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
- releases=(),
- unstable=True,
- ),
- self.on_GET,
- self.__class__.__name__,
- )
-
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None:
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 702b351d18..fb3211bf3a 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,22 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
+
+import attr
+
from synapse.api.errors import (
NotFoundError,
StoreError,
SynapseError,
UnrecognizedRequestError,
)
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_value_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RuleSpec:
+ scope: str
+ template: str
+ rule_id: str
+ attr: Optional[str]
class PushRuleRestServlet(RestServlet):
@@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet):
"Unrecognised request: You probably wanted a trailing slash"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet):
self._users_new_default_push_rules = hs.config.users_new_default_push_rules
- async def on_PUT(self, request, path):
+ async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
- if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
+ if "/" in spec.rule_id or "\\" in spec.rule_id:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
- if "attr" in spec:
+ if spec.attr:
await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
return 200, {}
- if spec["rule_id"].startswith("."):
+ if spec.rule_id.startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try:
(conditions, actions) = _rule_tuple_from_request_object(
- spec["template"], spec["rule_id"], content
+ spec.template, spec.rule_id, content
)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
@@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
- async def on_DELETE(self, request, path):
+ async def on_DELETE(
+ self, request: SynapseRequest, path: str
+ ) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
@@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet):
else:
raise
- async def on_GET(self, request, path):
+ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet):
rules = format_push_rules_for_user(requester.user, rules)
- path = path.split("/")[1:]
+ path_parts = path.split("/")[1:]
- if path == []:
+ if path_parts == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == "":
+ if path_parts[0] == "":
return 200, rules
- elif path[0] == "global":
- result = _filter_ruleset_with_path(rules["global"], path[1:])
+ elif path_parts[0] == "global":
+ result = _filter_ruleset_with_path(rules["global"], path_parts[1:])
return 200, result
else:
raise UnrecognizedRequestError()
- def notify_user(self, user_id):
+ def notify_user(self, user_id: str) -> None:
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
- async def set_rule_attr(self, user_id, spec, val):
- if spec["attr"] not in ("enabled", "actions"):
+ async def set_rule_attr(
+ self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+ ) -> None:
+ if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec["rule_id"]
+ rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
- if spec["attr"] == "enabled":
+ if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
@@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- return await self.store.set_push_rule_enabled(
+ await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
- elif spec["attr"] == "actions":
+ elif spec.attr == "actions":
+ if not isinstance(val, dict):
+ raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
+ if not isinstance(actions, list):
+ raise SynapseError(400, "Value for 'actions' must be dict")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec["rule_id"]
+ rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if user_id in self._users_new_default_push_rules:
@@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- return await self.store.set_push_rule_actions(
+ await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
-def _rule_spec_from_path(path):
+def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
Args:
- path (sequence[unicode]): the URL path components.
+ path: the URL path components.
Returns:
- dict: rule spec dict, containing scope/template/rule_id entries,
- and possibly attr.
+ rule spec, containing scope/template/rule_id entries, and possibly attr.
Raises:
UnrecognizedRequestError if the path components cannot be parsed.
@@ -237,17 +262,18 @@ def _rule_spec_from_path(path):
rule_id = path[0]
- spec = {"scope": scope, "template": template, "rule_id": rule_id}
-
path = path[1:]
+ attr = None
if len(path) > 0 and len(path[0]) > 0:
- spec["attr"] = path[0]
+ attr = path[0]
- return spec
+ return RuleSpec(scope, template, rule_id, attr)
-def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
+def _rule_tuple_from_request_object(
+ rule_template: str, rule_id: str, req_obj: JsonDict
+) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]:
if rule_template in ["override", "underride"]:
if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
@@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
return conditions, actions
-def _check_actions(actions):
+def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
@@ -290,7 +316,7 @@ def _check_actions(actions):
raise InvalidRuleException("Unrecognised action")
-def _filter_ruleset_with_path(ruleset, path):
+def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
@@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path):
if r["rule_id"] == rule_id:
the_rule = r
if the_rule is None:
- raise NotFoundError
+ raise NotFoundError()
path = path[1:]
if len(path) == 0:
@@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path):
raise UnrecognizedRequestError()
-def _priority_class_from_spec(spec):
- if spec["template"] not in PRIORITY_CLASS_MAP.keys():
- raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
- pc = PRIORITY_CLASS_MAP[spec["template"]]
+def _priority_class_from_spec(spec: RuleSpec) -> int:
+ if spec.template not in PRIORITY_CLASS_MAP.keys():
+ raise InvalidRuleException("Unknown template: %s" % (spec.template))
+ pc = PRIORITY_CLASS_MAP[spec.template]
return pc
-def _namespaced_rule_id_from_spec(spec):
- return _namespaced_rule_id(spec, spec["rule_id"])
+def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
+ return _namespaced_rule_id(spec, spec.rule_id)
-def _namespaced_rule_id(spec, rule_id):
- return "global/%s/%s" % (spec["template"], rule_id)
+def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
+ return "global/%s/%s" % (spec.template, rule_id)
class InvalidRuleException(Exception):
pass
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushRuleRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index d9ab836cd8..9770413c61 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,13 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
"/(?P<event_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id, receipt_type, event_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 7b5f49d635..8f3dd2a101 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -14,7 +14,9 @@
# limitations under the License.
import logging
import random
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.server import Request
import synapse
import synapse.api.auth
@@ -29,15 +31,13 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
-from synapse.config.captcha import CaptchaConfig
-from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
@@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html
)
- async def on_GET(self, request, medium):
+ async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@@ -387,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
unstable=True,
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
@@ -402,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
if not self.hs.config.enable_registration:
@@ -419,11 +403,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -445,23 +425,21 @@ class RegisterRestServlet(RestServlet):
)
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False)
- kind = b"user"
- if b"kind" in request.args:
- kind = request.args[b"kind"][0]
+ kind = parse_string(request, "kind", default="user")
- if kind == b"guest":
+ if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr)
return ret
- elif kind != b"user":
+ elif kind != "user":
raise UnrecognizedRequestError(
- "Do not understand membership kind: %s" % (kind.decode("utf8"),)
+ f"Do not understand membership kind: {kind}",
)
if self._msc2918_enabled:
@@ -748,8 +726,12 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
async def _do_appservice_registration(
- self, username, as_token, body, should_issue_refresh_token: bool = False
- ):
+ self,
+ username: str,
+ as_token: str,
+ body: JsonDict,
+ should_issue_refresh_token: bool = False,
+ ) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
)
@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
- ):
+ ) -> JsonDict:
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result
- async def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(
+ self, params: JsonDict, address: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows(
- # technically `config` has to provide *all* of these interfaces, not just one
- config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
- auth_handler: AuthHandler,
+ config: HomeServerConfig, auth_handler: AuthHandler
) -> List[List[str]]:
"""Get a suitable flows list for registration
@@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 0821cd285f..0b0711c03c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
"""
import logging
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
http_server.register_paths(
"POST",
client_patterns(self.PATTERN + "$", releases=()),
@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
self.__class__.__name__,
)
- def on_PUT(self, request, *args, **kwargs):
+ def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
- request, self.on_PUT_or_POST, request, *args, **kwargs
+ request,
+ self.on_PUT_or_POST,
+ request,
+ room_id,
+ parent_id,
+ relation_type,
+ event_type,
+ txn_id,
)
async def on_PUT_or_POST(
- self, request, room_id, parent_id, relation_type, event_type, txn_id=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ key: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return 200, return_value
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index 07ea39a8a3..d4a4adb50c 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -14,26 +14,35 @@
import logging
from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- async def on_POST(self, request, room_id, event_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c5c54564be..9b0c546505 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,11 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
+from twisted.web.server import Request
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -30,6 +32,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
- def on_PUT(self, request, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
info, _ = await self._room_creation_handler.create_room(
@@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
return 200, info
- def get_room_config(self, request):
+ def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__,
)
- def on_GET_no_state_key(self, request, room_id, event_type):
+ def on_GET_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_GET(request, room_id, event_type, "")
- def on_PUT_no_state_key(self, request, room_id, event_type):
+ def on_PUT_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "")
- async def on_GET(self, request, room_id, event_type, state_key):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
@@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ # Format must be event or content, per the parse_string call above.
+ raise RuntimeError(f"Unknown format: {format:r}.")
+
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if txn_id:
@@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- async def on_POST(self, request, room_id, event_type, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- event_dict = {
+ event_dict: JsonDict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
+ # Twisted will have processed the args by now.
+ assert request.args is not None
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
@@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_GET(self, request, room_id, event_type, txn_id):
+ def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Tuple[int, str]:
return 200, "Not implemented"
- def on_PUT(self, request, room_id, event_type, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id]
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
- def on_PUT(self, request, room_identifier, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_identifier: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
try:
@@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
return 200, data
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server")
@@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
# TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
@@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
users_with_profile = await self.message_handler.get_joined_members(
@@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
+ # Twisted will have processed the args by now.
+ assert request.args is not None
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
@@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, List[JsonDict]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
events = await self.message_handler.get_state_events(
@@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
@@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
event = await self.event_handler.get_event(
@@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ event_dict = await self._event_serializer.serialize_event(event, time_now)
+ return 200, event_dict
- return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet):
@@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
- def on_PUT(self, request, room_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/[invite|join|leave]
PATTERNS = (
"/rooms/(?P<room_id>[^/]*)/"
@@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ membership_action: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
@@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
- def _has_3pid_invite_keys(self, content):
+ def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
- def on_PUT(self, request, room_id, membership_action, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, event_id, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_id: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_PUT(self, request, room_id, event_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
hs.config.worker.writers.typing == hs.get_instance_name()
)
- async def on_PUT(self, request, room_id, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer:
@@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
alias_list = await self.directory_handler.get_aliases_for_room(
@@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
-def register_txn_path(servlet, regex_string, http_server, with_get=False):
+def register_txn_path(
+ servlet: RestServlet,
+ regex_string: str,
+ http_server: HttpServer,
+ with_get: bool = False,
+) -> None:
"""Registers a transaction-based path.
This registers two paths:
@@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
POST regex_string
Args:
- regex_string (str): The regex string to register. Must NOT have a
- trailing $ as this string will be appended to.
- http_server : The http_server to register paths with.
+ regex_string: The regex string to register. Must NOT have a
+ trailing $ as this string will be appended to.
+ http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
+ on_POST = getattr(servlet, "on_POST", None)
+ on_PUT = getattr(servlet, "on_PUT", None)
+ if on_POST is None or on_PUT is None:
+ raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
- servlet.on_POST,
+ on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_PUT,
+ on_PUT,
servlet.__class__.__name__,
)
+ on_GET = getattr(servlet, "on_GET", None)
if with_get:
+ if on_GET is None:
+ raise RuntimeError(
+ "register_txn_path called with with_get = True, but no on_GET method exists"
+ )
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_GET,
+ on_GET,
servlet.__class__.__name__,
)
@@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
-def register_servlets(hs: "HomeServer", http_server, is_worker=False):
+def register_servlets(
+ hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
+) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomForgetRestServlet(hs).register(http_server)
-def register_deprecated_servlets(hs, http_server):
+def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomInitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 3172aba605..ed96978448 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,10 +14,14 @@
import logging
import re
+from typing import TYPE_CHECKING, Awaitable, List, Tuple
+
+from twisted.web.server import Request
from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.appservice import ApplicationService
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -25,10 +29,14 @@ from synapse.http.servlet import (
parse_string,
parse_strings_from_args,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
@@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
- async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+ async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
(
most_recent_prev_event_id,
most_recent_prev_event_depth,
@@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
def _create_insertion_event_dict(
self, sender: str, room_id: str, origin_server_ts: int
- ):
+ ) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
and a random chunk ID.
@@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
origin_server_ts: Timestamp when the event was sent
Returns:
- Tuple of event ID and stream ordering position
+ The new event dictionary to insert.
"""
next_chunk_id = random_string(8)
@@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet):
return create_requester(user_id, app_service=app_service)
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
if not requester.app_service:
@@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["state_events_at_start", "events"])
+ assert request.args is not None
prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
chunk_id_from_query = parse_string(request, "chunk_id")
@@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
],
}
- def on_GET(self, request, room_id):
+ def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
return 501, "Not implemented"
- def on_PUT(self, request, room_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
msc2716_enabled = hs.config.experimental.msc2716_enabled
if msc2716_enabled:
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 263596be86..37e39570f6 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,16 +13,23 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet):
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- async def on_PUT(self, request, room_id, session_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional)
@@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet):
ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
return 200, ret
- async def on_GET(self, request, room_id, session_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API.
@@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys
- async def on_DELETE(self, request, room_id, session_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Deletes one or more encrypted E2E room keys for a user for backup purposes.
@@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user
@@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- async def on_GET(self, request, version):
+ async def on_GET(
+ self, request: SynapseRequest, version: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the
@@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet):
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info
- async def on_DELETE(self, request, version):
+ async def on_DELETE(
+ self, request: SynapseRequest, version: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most
@@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet):
await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {}
- async def on_PUT(self, request, version):
+ async def on_PUT(
+ self, request: SynapseRequest, version: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Update the information about a given version of the user's room_keys backup.
@@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomKeysServlet(hs).register(http_server)
RoomKeysVersionServlet(hs).register(http_server)
RoomKeysNewVersionServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index d537d811d8..3322c8ef48 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -13,15 +13,21 @@
# limitations under the License.
import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Awaitable, Tuple
from synapse.http import servlet
+from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.device_message_handler = hs.get_device_message_handler()
@trace(opname="sendToDevice")
- def on_PUT(self, request, message_type, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, message_type: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("message_type", message_type)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
- async def _put(self, request, message_type, txn_id):
+ async def _put(
+ self, request: SynapseRequest, message_type: str, txn_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"]
)
- response: Tuple[int, dict] = (200, {})
- return response
+ return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SendToDeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 65c37be3e9..1259058b9b 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,12 +14,24 @@
import itertools
import logging
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.api.presence import UserPresenceState
+from synapse.events import EventBase
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
@@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet):
The room, encoded in our response format
"""
- def serialize(events):
+ def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
return self._event_serializer.serialize_events(
events,
time_now=time_now,
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 94ff3719ce..914fb3acf5 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -15,28 +15,37 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
+
+from twisted.python.failure import Failure
+from twisted.web.server import Request
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import JsonDict
from synapse.util.async_helpers import ObservableDeferred
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
- self.transactions = {
- # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
- }
+ # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
+ self.transactions: Dict[
+ str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
+ ] = {}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
- def _get_transaction_key(self, request):
+ def _get_transaction_key(self, request: Request) -> str:
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
@@ -45,15 +54,21 @@ class HttpTransactionCache:
path and the access_token for the requesting user.
Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
+ request: The incoming request. Must contain an access_token.
Returns:
- str: A transaction key
+ A transaction key
"""
+ assert request.path is not None
token = self.auth.get_access_token_from_request(request)
return request.path.decode("utf8") + "/" + token
- def fetch_or_execute_request(self, request, fn, *args, **kwargs):
+ def fetch_or_execute_request(
+ self,
+ request: Request,
+ fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
+ *args: Any,
+ **kwargs: Any,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -64,15 +79,20 @@ class HttpTransactionCache:
self._get_transaction_key(request), fn, *args, **kwargs
)
- def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
+ def fetch_or_execute(
+ self,
+ txn_key: str,
+ fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
+ *args: Any,
+ **kwargs: Any,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
- txn_key (str): A key to ensure idempotency should fetch_or_execute be
- called again at a later point in time.
- fn (function): A function which returns a tuple of
- (response_code, response_dict).
+ txn_key: A key to ensure idempotency should fetch_or_execute be
+ called again at a later point in time.
+ fn: A function which returns a tuple of (response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
@@ -90,7 +110,7 @@ class HttpTransactionCache:
# if the request fails with an exception, remove it
# from the transaction map. This is done to ensure that we don't
# cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
+ def remove_from_map(err: Failure) -> None:
self.transactions.pop(txn_key, None)
# we deliberately do not propagate the error any further, as we
# expect the observers to have reported it.
@@ -99,7 +119,7 @@ class HttpTransactionCache:
return make_deferred_yieldable(observable.observe())
- def _cleanup(self):
+ def _cleanup(self) -> None:
now = self.clock.time_msec()
for key in list(self.transactions):
ts = self.transactions[key][1]
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
new file mode 100644
index 0000000000..2e6706dbfa
--- /dev/null
+++ b/synapse/rest/media/v1/oembed.py
@@ -0,0 +1,155 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import attr
+
+from synapse.http.client import SimpleHttpClient
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html: Optional[str]
+ url: Optional[str]
+ title: Optional[str]
+ # Number of seconds to cache the content.
+ cache_age: int
+
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
+
+class OEmbedProvider:
+ """
+ A helper for accessing oEmbed content.
+
+ It can be used to check if a URL should be accessed via oEmbed and for
+ requesting/parsing oEmbed content.
+ """
+
+ def __init__(self, hs: "HomeServer", client: SimpleHttpClient):
+ self._oembed_patterns = {}
+ for oembed_endpoint in hs.config.oembed.oembed_patterns:
+ api_endpoint = oembed_endpoint.api_endpoint
+
+ # Only JSON is supported at the moment. This could be declared in
+ # the formats field. Otherwise, if the endpoint ends in .xml assume
+ # it doesn't support JSON.
+ if (
+ oembed_endpoint.formats is not None
+ and "json" not in oembed_endpoint.formats
+ ) or api_endpoint.endswith(".xml"):
+ logger.info(
+ "Ignoring oEmbed endpoint due to not supporting JSON: %s",
+ api_endpoint,
+ )
+ continue
+
+ # Iterate through each URL pattern and point it to the endpoint.
+ for pattern in oembed_endpoint.url_patterns:
+ self._oembed_patterns[pattern] = api_endpoint
+ self._client = client
+
+ def get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Args:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in self._oembed_patterns.items():
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Args:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+
+ # Note that only the JSON format is supported, some endpoints want
+ # this in the URL, others want it as an argument.
+ endpoint = endpoint.replace("{format}", "json")
+
+ result = await self._client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ args={"url": url, "format": "json"},
+ )
+
+ # Ensure there's a version of 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0f051d4041..f108da05db 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import re
import shutil
import sys
import traceback
-from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
@@ -43,6 +43,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -71,65 +73,44 @@ OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
-# A map of globs to API endpoints.
-_oembed_globs = {
- # Twitter.
- "https://publish.twitter.com/oembed": [
- "https://twitter.com/*/status/*",
- "https://*.twitter.com/*/status/*",
- "https://twitter.com/*/moments/*",
- "https://*.twitter.com/*/moments/*",
- # Include the HTTP versions too.
- "http://twitter.com/*/status/*",
- "http://*.twitter.com/*/status/*",
- "http://twitter.com/*/moments/*",
- "http://*.twitter.com/*/moments/*",
- ],
-}
-# Convert the globs to regular expressions.
-_oembed_patterns = {}
-for endpoint, globs in _oembed_globs.items():
- for glob in globs:
- # Convert the glob into a sane regular expression to match against. The
- # rules followed will be slightly different for the domain portion vs.
- # the rest.
- #
- # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
- # 2. The domain can have globs, but we limit it to characters that can
- # reasonably be a domain part.
- # TODO: This does not attempt to handle Unicode domain names.
- # 3. Other parts allow a glob to be any one, or more, characters.
- results = urlparse.urlparse(glob)
-
- # Ensure the scheme does not have wildcards (and is a sane scheme).
- if results.scheme not in {"http", "https"}:
- raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
-
- pattern = urlparse.urlunparse(
- [
- results.scheme,
- re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
- ]
- + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
- )
- _oembed_patterns[re.compile(pattern)] = endpoint
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class MediaInfo:
+ """
+ Information parsed from downloading media being previewed.
+ """
-@attr.s(slots=True)
-class OEmbedResult:
- # Either HTML content or URL must be provided.
- html = attr.ib(type=Optional[str])
- url = attr.ib(type=Optional[str])
- title = attr.ib(type=Optional[str])
- # Number of seconds to cache the content.
- cache_age = attr.ib(type=int)
+ # The Content-Type header of the response.
+ media_type: str
+ # The length (in bytes) of the downloaded media.
+ media_length: int
+ # The media filename, according to the server. This is parsed from the
+ # returned headers, if possible.
+ download_name: Optional[str]
+ # The time of the preview.
+ created_ts_ms: int
+ # Information from the media storage provider about where the file is stored
+ # on disk.
+ filesystem_id: str
+ filename: str
+ # The URI being previewed.
+ uri: str
+ # The HTTP response code.
+ response_code: int
+ # The timestamp (in milliseconds) of when this preview expires.
+ expires: int
+ # The ETag header of the response.
+ etag: Optional[str]
-class OEmbedError(Exception):
- """An error occurred processing the oEmbed object."""
+class PreviewUrlResource(DirectServeJsonResource):
+ """
+ Generating URL previews is a complicated task which many potential pitfalls.
+ See docs/development/url_previews.md for discussion of the design and
+ algorithm followed in this module.
+ """
-class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(
@@ -157,6 +138,8 @@ class PreviewUrlResource(DirectServeJsonResource):
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
+ self._oembed = OEmbedProvider(hs, self.client)
+
# We run the background jobs if we're the instance specified (or no
# instance is specified, where we assume there is only one instance
# serving media).
@@ -275,18 +258,17 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("got media_info of '%s'", media_info)
- if _is_media(media_info["media_type"]):
- file_id = media_info["filesystem_id"]
+ if _is_media(media_info.media_type):
+ file_id = media_info.filesystem_id
dims = await self.media_repo._generate_thumbnails(
- None, file_id, file_id, media_info["media_type"], url_cache=True
+ None, file_id, file_id, media_info.media_type, url_cache=True
)
og = {
- "og:description": media_info["download_name"],
- "og:image": "mxc://%s/%s"
- % (self.server_name, media_info["filesystem_id"]),
- "og:image:type": media_info["media_type"],
- "matrix:image:size": media_info["media_length"],
+ "og:description": media_info.download_name,
+ "og:image": f"mxc://{self.server_name}/{media_info.filesystem_id}",
+ "og:image:type": media_info.media_type,
+ "matrix:image:size": media_info.media_length,
}
if dims:
@@ -296,14 +278,14 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Couldn't get dims for %s" % url)
# define our OG response for this media
- elif _is_html(media_info["media_type"]):
+ elif _is_html(media_info.media_type):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
- with open(media_info["filename"], "rb") as file:
+ with open(media_info.filename, "rb") as file:
body = file.read()
- encoding = get_html_media_encoding(body, media_info["media_type"])
- og = decode_and_calc_og(body, media_info["uri"], encoding)
+ encoding = get_html_media_encoding(body, media_info.media_type)
+ og = decode_and_calc_og(body, media_info.uri, encoding)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
@@ -311,14 +293,14 @@ class PreviewUrlResource(DirectServeJsonResource):
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
image_info = await self._download_url(
- _rebase_url(og["og:image"], media_info["uri"]), user
+ _rebase_url(og["og:image"], media_info.uri), user
)
- if _is_media(image_info["media_type"]):
+ if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images
- file_id = image_info["filesystem_id"]
+ file_id = image_info.filesystem_id
dims = await self.media_repo._generate_thumbnails(
- None, file_id, file_id, image_info["media_type"], url_cache=True
+ None, file_id, file_id, image_info.media_type, url_cache=True
)
if dims:
og["og:image:width"] = dims["width"]
@@ -326,12 +308,11 @@ class PreviewUrlResource(DirectServeJsonResource):
else:
logger.warning("Couldn't get dims for %s", og["og:image"])
- og["og:image"] = "mxc://%s/%s" % (
- self.server_name,
- image_info["filesystem_id"],
- )
- og["og:image:type"] = image_info["media_type"]
- og["matrix:image:size"] = image_info["media_length"]
+ og[
+ "og:image"
+ ] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
+ og["og:image:type"] = image_info.media_type
+ og["matrix:image:size"] = image_info.media_length
else:
del og["og:image"]
else:
@@ -357,98 +338,17 @@ class PreviewUrlResource(DirectServeJsonResource):
# store OG in history-aware DB cache
await self.store.store_url_cache(
url,
- media_info["response_code"],
- media_info["etag"],
- media_info["expires"] + media_info["created_ts"],
+ media_info.response_code,
+ media_info.etag,
+ media_info.expires + media_info.created_ts_ms,
jsonog,
- media_info["filesystem_id"],
- media_info["created_ts"],
+ media_info.filesystem_id,
+ media_info.created_ts_ms,
)
return jsonog.encode("utf8")
- def _get_oembed_url(self, url: str) -> Optional[str]:
- """
- Check whether the URL should be downloaded as oEmbed content instead.
-
- Args:
- url: The URL to check.
-
- Returns:
- A URL to use instead or None if the original URL should be used.
- """
- for url_pattern, endpoint in _oembed_patterns.items():
- if url_pattern.fullmatch(url):
- return endpoint
-
- # No match.
- return None
-
- async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
- """
- Request content from an oEmbed endpoint.
-
- Args:
- endpoint: The oEmbed API endpoint.
- url: The URL to pass to the API.
-
- Returns:
- An object representing the metadata returned.
-
- Raises:
- OEmbedError if fetching or parsing of the oEmbed information fails.
- """
- try:
- logger.debug("Trying to get oEmbed content for url '%s'", url)
- result = await self.client.get_json(
- endpoint,
- # TODO Specify max height / width.
- # Note that only the JSON format is supported.
- args={"url": url},
- )
-
- # Ensure there's a version of 1.0.
- if result.get("version") != "1.0":
- raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
- oembed_type = result.get("type")
-
- # Ensure the cache age is None or an int.
- cache_age = result.get("cache_age")
- if cache_age:
- cache_age = int(cache_age)
-
- oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
-
- # HTML content.
- if oembed_type == "rich":
- oembed_result.html = result.get("html")
- return oembed_result
-
- if oembed_type == "photo":
- oembed_result.url = result.get("url")
- return oembed_result
-
- # TODO Handle link and video types.
-
- if "thumbnail_url" in result:
- oembed_result.url = result.get("thumbnail_url")
- return oembed_result
-
- raise OEmbedError("Incompatible oEmbed information.")
-
- except OEmbedError as e:
- # Trap OEmbedErrors first so we can directly re-raise them.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- raise
-
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
- raise OEmbedError() from e
-
- async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
+ async def _download_url(self, url: str, user: str) -> MediaInfo:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -459,11 +359,11 @@ class PreviewUrlResource(DirectServeJsonResource):
# If this URL can be accessed via oEmbed, use that instead.
url_to_download: Optional[str] = url
- oembed_url = self._get_oembed_url(url)
+ oembed_url = self._oembed.get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
try:
- oembed_result = await self._get_oembed_content(oembed_url, url)
+ oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
if oembed_result.url:
url_to_download = oembed_result.url
elif oembed_result.html:
@@ -560,18 +460,18 @@ class PreviewUrlResource(DirectServeJsonResource):
# therefore not expire it.
raise
- return {
- "media_type": media_type,
- "media_length": length,
- "download_name": download_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- "filename": fname,
- "uri": uri,
- "response_code": code,
- "expires": expires,
- "etag": etag,
- }
+ return MediaInfo(
+ media_type=media_type,
+ media_length=length,
+ download_name=download_name,
+ created_ts_ms=time_now_ms,
+ filesystem_id=file_id,
+ filename=fname,
+ uri=uri,
+ response_code=code,
+ expires=expires,
+ etag=etag,
+ )
def _start_expire_url_cache_data(self):
return run_as_background_process(
@@ -717,7 +617,7 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str:
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
-) -> Dict[str, Optional[str]]:
+) -> JsonDict:
"""
Calculate metadata for an HTML document.
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index fc62a09b7f..3869d18003 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -63,8 +63,8 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
return
- # It should be impossible to get here without having first been through
- # the pick-a-username step, which ensures chosen_localpart gets set.
+ # It should be impossible to get here without either the user or the mapping provider
+ # having chosen a username, which ensures chosen_localpart gets set.
if not session.chosen_localpart:
logger.warning("Session has no user name selected")
self._sso_handler.render_error(
|