diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 3ebe401861..6c24b96c54 100644
--- a/synapse/rest/client/account_validity.py
+++ b/synapse/rest/client/account_validity.py
@@ -13,24 +13,27 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer, respond_with_html
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
hs.config.account_validity.account_validity_invalid_token_template
)
- async def on_GET(self, request):
- if b"token" not in request.args:
- raise SynapseError(400, "Missing renewal token")
- renewal_token = request.args[b"token"][0]
+ async def on_GET(self, request: Request) -> None:
+ renewal_token = parse_string(request, "token", required=True)
(
token_valid,
token_stale,
expiration_ts,
- ) = await self.account_activity_handler.renew_account(
- renewal_token.decode("utf8")
- )
+ ) = await self.account_activity_handler.renew_account(renewal_token)
if token_valid:
status_code = 200
@@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
hs.config.account_validity.account_validity_renew_by_email_enabled
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)
@@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 6ea1b50a62..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -16,7 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
+from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
+ self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ )
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -95,29 +102,32 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.recaptcha_public_key,
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.TERMS, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
@@ -127,10 +137,33 @@ class AuthRestServlet(RestServlet):
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ token = parse_string(request, "token", required=True)
+ authdict = {"session": session, "token": token}
+
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ error=e.msg,
+ )
+ else:
+ html = self.success_template.render()
+
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 88e3aac797..65b3b5ce2c 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
@@ -61,8 +62,19 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES
+ if self.config.experimental.msc3283_enabled:
+ response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
+ "enabled": self.config.enable_set_displayname
+ }
+ response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
+ "enabled": self.config.enable_set_avatar_url
+ }
+ response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
+ "enabled": self.config.enable_3pid_changes
+ }
+
return 200, response
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index ffa075c8e5..ee247e3d1e 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.errors import (
AuthError,
@@ -22,14 +24,19 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
@@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
- res = await self.directory_handler.get_association(room_alias)
+ res = await self.directory_handler.get_association(room_alias_obj)
return 200, res
- async def on_PUT(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_PUT(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
@@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
)
logger.debug("Got content: %s", content)
- logger.debug("Got room name: %s", room_alias.to_string())
+ logger.debug("Got room name: %s", room_alias_obj.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
@@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
- requester, room_alias, room_id, servers
+ requester, room_alias_obj, room_id, servers
)
return 200, {}
- async def on_DELETE(self, request, room_alias):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
+
try:
service = self.auth.get_appservice_by_req(request)
- room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association(
- service, room_alias
+ service, room_alias_obj
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
- room_alias.to_string(),
+ room_alias_obj.to_string(),
)
return 200, {}
except InvalidClientCredentialsError:
@@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
user = requester.user
- room_alias = RoomAlias.from_string(room_alias)
-
- await self.directory_handler.delete_association(requester, room_alias)
+ await self.directory_handler.delete_association(requester, room_alias_obj)
logger.info(
- "User %s deleted alias %s", user.to_string(), room_alias.to_string()
+ "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
)
return 200, {}
@@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
- async def on_PUT(self, request, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
return 200, {}
- async def on_DELETE(self, request, room_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.edit_published_room_list(
@@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- def on_PUT(self, request, network_id, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- return self._edit(request, network_id, room_id, visibility)
+ return await self._edit(request, network_id, room_id, visibility)
- def on_DELETE(self, request, network_id, room_id):
- return self._edit(request, network_id, room_id, "private")
+ async def on_DELETE(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ return await self._edit(request, network_id, room_id, "private")
- async def _edit(self, request, network_id, room_id, visibility):
+ async def _edit(
+ self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index d0d9d30d40..012491f597 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,8 +15,9 @@
# limitations under the License.
import logging
+from typing import Any
-from synapse.api.errors import SynapseError
+from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet):
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ device_keys = body.get("device_keys")
+ if not isinstance(device_keys, dict):
+ raise InvalidAPICallError("'device_keys' must be a JSON object")
+
+ def is_list_of_strings(values: Any) -> bool:
+ return isinstance(values, list) and all(isinstance(v, str) for v in values)
+
+ if any(not is_list_of_strings(keys) for keys in device_keys.values()):
+ raise InvalidAPICallError(
+ "'device_keys' values must be a list of strings",
+ )
+
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0c8d8967b7..11d07776b2 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -104,6 +104,12 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_account.burst_count,
)
+ # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
+ # The reason for this is to ensure that the auth_provider_ids are registered
+ # with SsoHandler, which in turn ensures that the login/registration prometheus
+ # counters are initialised for the auth_provider_ids.
+ _load_sso_handlers(hs)
+
def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
@@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
- if hs.config.cas_enabled:
- hs.get_cas_handler()
- if hs.config.saml2_enabled:
- hs.get_saml_handler()
- if hs.config.oidc_enabled:
- hs.get_oidc_handler()
+ _load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
@@ -598,3 +599,19 @@ def register_servlets(hs, http_server):
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
+
+
+def _load_sso_handlers(hs: "HomeServer"):
+ """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
+
+ This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
+ with the main SsoHandler.
+
+ It's safe to call this multiple times.
+ """
+ if hs.config.cas.cas_enabled:
+ hs.get_cas_handler()
+ if hs.config.saml2.saml2_enabled:
+ hs.get_saml_handler()
+ if hs.config.oidc.oidc_enabled:
+ hs.get_oidc_handler()
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 84619c5e41..98604a9388 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import respond_with_html_bytes
+from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
@@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 027f8b81fa..43c04fac6f 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -13,27 +13,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
return 200, {"available": True}
+class RegistrationTokenValidityRestServlet(RestServlet):
+ """Check the validity of a registration token.
+
+ Example:
+
+ GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+ 200 OK
+
+ {
+ "valid": true
+ }
+ """
+
+ PATTERNS = client_patterns(
+ f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+ releases=(),
+ unstable=True,
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.ratelimiter = Ratelimiter(
+ store=self.store,
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+ burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+ )
+
+ async def on_GET(self, request):
+ await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
+ token = parse_string(request, "token", required=True)
+ valid = await self.store.registration_token_is_valid(token)
+
+ return 200, {"valid": valid}
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
)
if registered:
+ # Check if a token was used to authenticate registration
+ registration_token = await self.auth_handler.get_session_data(
+ session_id,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ )
+ if registration_token:
+ # Increment the `completed` counter for the token
+ await self.store.use_registration_token(registration_token)
+ # Indicate that the token has been successfully used so that
+ # pending is not decremented again when expiring old UIA sessions.
+ await self.store.mark_ui_auth_stage_complete(
+ session_id,
+ LoginType.REGISTRATION_TOKEN,
+ True,
+ )
+
await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
+ # Prepend registration token to all flows if we're requiring a token
+ if config.registration_requires_token:
+ for flow in flows:
+ flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
return flows
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
+ RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6d1b083acb..6a7792e18b 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -13,18 +13,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util import stringutils
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
}
Creates a new room and shuts down the old one. Returns the ID of the new room.
-
- Args:
- hs (synapse.server.HomeServer):
"""
PATTERNS = client_patterns(
@@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/upgrade$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index d2e7f04b40..1d90493eb0 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
releases=(), # This is an unstable feature
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
if not self.user_directory_active:
raise SynapseError(
@@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
return 200, {"joined": list(rooms)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e18f4d01b3..65c37be3e9 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,17 +14,26 @@
import itertools
import logging
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.presence import UserPresenceState
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
+from synapse.handlers.sync import (
+ ArchivedSyncResult,
+ InvitedSyncResult,
+ JoinedSyncResult,
+ KnockedSyncResult,
+ SyncConfig,
+ SyncResult,
+)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
return 200, {}
time_now = self.clock.time_msec()
+ # We know that the the requester has an access token since appservices
+ # cannot use sync.
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
@@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
- async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ async def encode_response(
+ self,
+ time_now: int,
+ sync_result: SyncResult,
+ access_token_id: Optional[int],
+ filter: FilterCollection,
+ ) -> JsonDict:
logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
@@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
logger.debug("building sync response dict")
- response: dict = defaultdict(dict)
+ response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.account_data:
@@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
if archived:
response["rooms"][Membership.LEAVE] = archived
+ # By the time we get here groups is no longer optional.
+ assert sync_result.groups is not None
if sync_result.groups.join:
response["groups"][Membership.JOIN] = sync_result.groups.join
if sync_result.groups.invite:
@@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
return response
@staticmethod
- def encode_presence(events, time_now):
+ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
return {
"events": [
{
@@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
}
async def encode_joined(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[JoinedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the joined rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
- results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list<str>): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the joined rooms list, in our
- response format
+ The joined rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
return joined
- async def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(
+ self,
+ rooms: List[InvitedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the invited rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is invited to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is invited to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the invited rooms list, in our
- response format
+ The invited rooms list, in our response format
"""
invited = {}
for room in rooms:
@@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[KnockedSyncResult],
time_now: int,
- token_id: int,
+ token_id: Optional[int],
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
@@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
return knocked
async def encode_archived(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the archived rooms in a sync result
Args:
- rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
- sync results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list<str>): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
- to client format
+ event_formatter: function to convert from federation format to client format
Returns:
- dict[str, dict[str, object]]: The invited rooms list, in our
- response format
+ The archived rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
return joined
async def encode_room(
- self, room, time_now, token_id, joined, only_fields, event_formatter
- ):
+ self,
+ room: Union[JoinedSyncResult, ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ joined: bool,
+ only_fields: Optional[List[str]],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Args:
- room (JoinedSyncResult|ArchivedSyncResult): sync result for a
- single room
- time_now (int): current time - used as a baseline for age
- calculations
- token_id (int): ID of the user's auth token - used for namespacing
+ room: sync result for a single room
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- joined (bool): True if the user is joined to this room - will mean
+ joined: True if the user is joined to this room - will mean
we handle ephemeral events
- only_fields(list<str>): Optional. The list of event fields to include.
- event_formatter (func[dict]): function to convert from federation format
+ only_fields: Optional. The list of event fields to include.
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, object]: the room, encoded in our response format
+ The room, encoded in our response format
"""
def serialize(events):
@@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
account_data = room.account_data
- result = {
+ result: JsonDict = {
"timeline": {
"events": serialized_timeline,
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
@@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
}
if joined:
+ assert isinstance(room, JoinedSyncResult)
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
@@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
return result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index c14f83be18..c88cb9367c 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
@@ -54,12 +63,14 @@ class TagServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, tag):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -70,7 +81,9 @@ class TagServlet(RestServlet):
return 200, {}
- async def on_DELETE(self, request, user_id, room_id, tag):
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -80,6 +93,6 @@ class TagServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)
diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py
index b5c67c9bb6..b895c73acf 100644
--- a/synapse/rest/client/thirdparty.py
+++ b/synapse/rest/client/thirdparty.py
@@ -12,27 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.constants import ThirdPartyEntityKind
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols()
@@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols(
@@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py
index b2f858545c..c8c3b25bd3 100644
--- a/synapse/rest/client/tokenrefresh.py
+++ b/synapse/rest/client/tokenrefresh.py
@@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class TokenRefreshRestServlet(RestServlet):
"""
@@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> None:
raise AuthError(403, "tokenrefresh is no longer supported.")
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 7e8912f0b9..8852811114 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -13,29 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Searches for users in directory
Returns:
@@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index fa2e4e9cba..a1a815cf82 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -17,9 +17,17 @@
import logging
import re
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
in self.config.encryption_enabled_by_default_for_room_presets
)
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
return (
200,
{
@@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VersionsRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index f53020520d..9d46ed3af3 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -15,20 +15,27 @@
import base64
import hashlib
import hmac
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
@@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VoipRestServlet(hs).register(http_server)
|