diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 14eca70ba4..2e81eeff65 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -120,6 +120,7 @@ class ClientRestResource(JsonResource):
account_validity.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 42cc2b062a..6b85148a32 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -29,7 +29,12 @@ from synapse.rest.admin._base import (
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
-from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
+from synapse.rest.admin.rooms import (
+ JoinRoomAliasServlet,
+ ListRoomRestServlet,
+ RoomRestServlet,
+ ShutdownRoomRestServlet,
+)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
@@ -189,6 +194,8 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
+ RoomRestServlet(hs).register(http_server)
+ JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f9b8c0a4f0..8173baef8f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Optional
-from synapse.api.constants import Membership
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -25,11 +26,12 @@ from synapse.http.servlet import (
)
from synapse.rest.admin._base import (
admin_patterns,
+ assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
)
from synapse.storage.data_stores.main.room import RoomSortOrder
-from synapse.types import create_requester
+from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -57,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet):
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
+ self._replication = hs.get_replication_data_handler()
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@@ -71,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = await self._room_creation_handler.create_room(
+ info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@@ -92,6 +95,15 @@ class ShutdownRoomRestServlet(RestServlet):
# desirable in case the first attempt at blocking the room failed below.
await self.store.block_room(room_id, requester_user_id)
+ # We now wait for the create room to come back in via replication so
+ # that we can assume that all the joins/invites have propogated before
+ # we try and auto join below.
+ #
+ # TODO: Currently the events stream is written to from master
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
users = await self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
@@ -103,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet):
try:
target_requester = create_requester(user_id)
- await self.room_member_handler.update_membership(
+ _, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@@ -113,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False,
)
+ # Wait for leave to come in over replication before trying to forget.
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
await self.room_member_handler.forget(target_requester.user, room_id)
await self.room_member_handler.update_membership(
@@ -168,7 +185,7 @@ class ListRoomRestServlet(RestServlet):
in a dictionary containing room information. Supports pagination.
"""
- PATTERNS = admin_patterns("/rooms")
+ PATTERNS = admin_patterns("/rooms$")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -182,10 +199,23 @@ class ListRoomRestServlet(RestServlet):
# Extract query parameters
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
- order_by = parse_string(request, "order_by", default="alphabetical")
+ order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
if order_by not in (
RoomSortOrder.ALPHABETICAL.value,
RoomSortOrder.SIZE.value,
+ RoomSortOrder.NAME.value,
+ RoomSortOrder.CANONICAL_ALIAS.value,
+ RoomSortOrder.JOINED_MEMBERS.value,
+ RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
+ RoomSortOrder.VERSION.value,
+ RoomSortOrder.CREATOR.value,
+ RoomSortOrder.ENCRYPTION.value,
+ RoomSortOrder.FEDERATABLE.value,
+ RoomSortOrder.PUBLIC.value,
+ RoomSortOrder.JOIN_RULES.value,
+ RoomSortOrder.GUEST_ACCESS.value,
+ RoomSortOrder.HISTORY_VISIBILITY.value,
+ RoomSortOrder.STATE_EVENTS.value,
):
raise SynapseError(
400,
@@ -237,3 +267,98 @@ class ListRoomRestServlet(RestServlet):
response["prev_batch"] = 0
return 200, response
+
+
+class RoomRestServlet(RestServlet):
+ """Get room details.
+
+ TODO: Add on_POST to allow room creation without joining the room
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, room_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ ret = await self.store.get_room_with_stats(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ return 200, ret
+
+
+class JoinRoomAliasServlet(RestServlet):
+
+ PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.admin_handler = hs.get_handlers().admin_handler
+ self.state_handler = hs.get_state_handler()
+
+ async def on_POST(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+
+ assert_params_in_dict(content, ["user_id"])
+ target_user = UserID.from_string(content["user_id"])
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "This endpoint can only be used with local users")
+
+ if not await self.admin_handler.get_user(target_user):
+ raise NotFoundError("User not found")
+
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ try:
+ remote_room_hosts = [
+ x.decode("ascii") for x in request.args[b"server_name"]
+ ] # type: Optional[List[str]]
+ except Exception:
+ remote_room_hosts = None
+ elif RoomAlias.is_valid(room_identifier):
+ handler = self.room_member_handler
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ fake_requester = create_requester(target_user)
+
+ # send invite if room has "JoinRules.INVITE"
+ room_state = await self.state_handler.get_current_state(room_id)
+ join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+ if join_rules_event:
+ if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=fake_requester.user,
+ room_id=room_id,
+ action="invite",
+ remote_room_hosts=remote_room_hosts,
+ ratelimit=False,
+ )
+
+ await self.room_member_handler.update_membership(
+ requester=fake_requester,
+ target=fake_requester.user,
+ room_id=room_id,
+ action="join",
+ remote_room_hosts=remote_room_hosts,
+ ratelimit=False,
+ )
+
+ return 200, {"room_id": room_id}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 8551ac19b8..e7f6928c85 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -94,10 +94,10 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
- users = await self.store.get_users_paginate(
+ users, total = await self.store.get_users_paginate(
start, limit, user_id, guests, deactivated
)
- ret = {"users": users}
+ ret = {"users": users, "total": total}
if len(users) >= limit:
ret["next_token"] = str(start + len(users))
@@ -199,7 +199,7 @@ class UserRestServletV2(RestServlet):
user_id, threepid["medium"], threepid["address"], current_time
)
- if "avatar_url" in body:
+ if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
target_user, requester, body["avatar_url"], True
)
@@ -222,8 +222,14 @@ class UserRestServletV2(RestServlet):
else:
new_password = body["password"]
logout_devices = True
+
+ new_password_hash = await self.auth_handler.hash(new_password)
+
await self.set_password_handler.set_password(
- target_user.to_string(), new_password, logout_devices, requester
+ target_user.to_string(),
+ new_password_hash,
+ logout_devices,
+ requester,
)
if "deactivated" in body:
@@ -243,11 +249,11 @@ class UserRestServletV2(RestServlet):
else: # create user
password = body.get("password")
- if password is not None and (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
- raise SynapseError(400, "Invalid password")
+ password_hash = None
+ if password is not None:
+ if not isinstance(password, text_type) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ password_hash = await self.auth_handler.hash(password)
admin = body.get("admin", None)
user_type = body.get("user_type", None)
@@ -259,7 +265,7 @@ class UserRestServletV2(RestServlet):
user_id = await self.registration_handler.register_user(
localpart=target_user.localpart,
- password=password,
+ password_hash=password_hash,
admin=bool(admin),
default_display_name=displayname,
user_type=user_type,
@@ -276,7 +282,7 @@ class UserRestServletV2(RestServlet):
user_id, threepid["medium"], threepid["address"], current_time
)
- if "avatar_url" in body:
+ if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
user_id, requester, body["avatar_url"], True
)
@@ -298,7 +304,7 @@ class UserRegisterServlet(RestServlet):
NONCE_TIMEOUT = 60
def __init__(self, hs):
- self.handlers = hs.get_handlers()
+ self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
self.nonces = {}
self.hs = hs
@@ -362,16 +368,16 @@ class UserRegisterServlet(RestServlet):
400, "password must be specified", errcode=Codes.BAD_JSON
)
else:
- if (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
+ password = body["password"]
+ if not isinstance(password, text_type) or len(password) > 512:
raise SynapseError(400, "Invalid password")
- password = body["password"].encode("utf-8")
- if b"\x00" in password:
+ password_bytes = password.encode("utf-8")
+ if b"\x00" in password_bytes:
raise SynapseError(400, "Invalid password")
+ password_hash = await self.auth_handler.hash(password)
+
admin = body.get("admin", None)
user_type = body.get("user_type", None)
@@ -388,7 +394,7 @@ class UserRegisterServlet(RestServlet):
want_mac_builder.update(b"\x00")
want_mac_builder.update(username)
want_mac_builder.update(b"\x00")
- want_mac_builder.update(password)
+ want_mac_builder.update(password_bytes)
want_mac_builder.update(b"\x00")
want_mac_builder.update(b"admin" if admin else b"notadmin")
if user_type:
@@ -407,7 +413,7 @@ class UserRegisterServlet(RestServlet):
user_id = await register.registration_handler.register_user(
localpart=body["username"].lower(),
- password=body["password"],
+ password_hash=password_hash,
admin=bool(admin),
user_type=user_type,
)
@@ -523,6 +529,7 @@ class ResetPasswordRestServlet(RestServlet):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
async def on_POST(self, request, target_user_id):
@@ -539,8 +546,10 @@ class ResetPasswordRestServlet(RestServlet):
new_password = params["new_password"]
logout_devices = params.get("logout_devices", True)
+ new_password_hash = await self.auth_handler.hash(new_password)
+
await self._set_password_handler.set_password(
- target_user_id, new_password, logout_devices, requester
+ target_user_id, new_password_hash, logout_devices, requester
)
return 200, {}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d0d4999795..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,11 +14,6 @@
# limitations under the License.
import logging
-import xml.etree.ElementTree as ET
-
-from six.moves import urllib
-
-from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
+from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@@ -88,6 +83,7 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
+ self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -101,9 +97,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -119,6 +113,11 @@ class LoginRestServlet(RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.oidc_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -402,24 +401,27 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
- sso_url = self.get_sso_url(client_redirect_url)
+ sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url):
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
- client_redirect_url (bytes): the URL that we should redirect the
+ request: The client request to redirect.
+ client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
- bytes: URL to redirect to
+ URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@@ -427,19 +429,14 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
- super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url.encode("ascii")
- self.cas_service_url = hs.config.cas_service_url.encode("ascii")
+ self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url):
- client_redirect_url_param = urllib.parse.urlencode(
- {b"redirectUrl": client_redirect_url}
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return self._cas_handler.get_redirect_url(
+ {"redirectUrl": client_redirect_url}
).encode("ascii")
- hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
- service_param = urllib.parse.urlencode(
- {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
- ).encode("ascii")
- return b"%s/login?%s" % (self.cas_server_url, service_param)
class CasTicketServlet(RestServlet):
@@ -447,81 +444,25 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url
- self.cas_service_url = hs.config.cas_service_url
- self.cas_displayname_attribute = hs.config.cas_displayname_attribute
- self.cas_required_attributes = hs.config.cas_required_attributes
- self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_proxied_http_client()
-
- async def on_GET(self, request):
- client_redirect_url = parse_string(request, "redirectUrl", required=True)
- uri = self.cas_server_url + "/proxyValidate"
- args = {
- "ticket": parse_string(request, "ticket", required=True),
- "service": self.cas_service_url,
- }
- try:
- body = await self._http_client.get_raw(uri, args)
- except PartialDownloadError as pde:
- # Twisted raises this error if the connection is closed,
- # even if that's being used old-http style to signal end-of-data
- body = pde.response
- result = await self.handle_cas_response(request, body, client_redirect_url)
- return result
+ self._cas_handler = hs.get_cas_handler()
- def handle_cas_response(self, request, cas_response_body, client_redirect_url):
- user, attributes = self.parse_cas_response(cas_response_body)
- displayname = attributes.pop(self.cas_displayname_attribute, None)
+ async def on_GET(self, request: SynapseRequest) -> None:
+ client_redirect_url = parse_string(request, "redirectUrl")
+ ticket = parse_string(request, "ticket", required=True)
- for required_attribute, required_value in self.cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ # Maybe get a session ID (if this ticket is from user interactive
+ # authentication).
+ session = parse_string(request, "session")
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ # Either client_redirect_url or session must be provided.
+ if not client_redirect_url and not session:
+ message = "Missing string query parameter redirectUrl or session"
+ raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
- return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url, displayname
+ await self._cas_handler.handle_ticket(
+ request, ticket, client_redirect_url, session
)
- def parse_cas_response(self, cas_response_body):
- user = None
- attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
-
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -529,69 +470,25 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url):
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
-class SSOAuthHandler(object):
- """
- Utility class for Resources and Servlets which handle the response from a SSO
- service
+class OIDCRedirectServlet(BaseSSORedirectServlet):
+ """Implementation for /login/sso/redirect for the OIDC login flow."""
- Args:
- hs (synapse.server.HomeServer)
- """
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
- self._hostname = hs.hostname
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
- self._macaroon_gen = hs.get_macaroon_generator()
-
- # Load the redirect page HTML template
- self._template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
- )[0]
-
- self._server_name = hs.config.server_name
-
- # cast to tuple for use with str.startswith
- self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
-
- async def on_successful_auth(
- self, username, request, client_redirect_url, user_display_name=None
- ):
- """Called once the user has successfully authenticated with the SSO.
-
- Registers the user if necessary, and then returns a redirect (with
- a login token) to the client.
-
- Args:
- username (unicode|bytes): the remote user id. We'll map this onto
- something sane for a MXID localpath.
-
- request (SynapseRequest): the incoming request from the browser. We'll
- respond to it with a redirect.
-
- client_redirect_url (unicode): the redirect_url the client gave us when
- it first started the process.
-
- user_display_name (unicode|None): if set, and we have to register a new user,
- we will set their displayname to this.
-
- Returns:
- Deferred[none]: Completes once we have handled the request.
- """
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
- )
+ self._oidc_handler = hs.get_oidc_handler()
- self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url
)
@@ -602,3 +499,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
+ elif hs.config.oidc_enabled:
+ OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1cf3caf832..b0c30b65be 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -34,10 +34,10 @@ class LogoutRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
- # the acccess token wasn't associated with a device.
+ # The access token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
await self._auth_handler.delete_access_token(access_token)
@@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
# first delete all of the user's devices
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index e788eb0193..43b64608e7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
+ HttpResponseException,
InvalidClientCredentialsError,
SynapseError,
)
@@ -92,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
- info = await self._room_creation_handler.create_room(
+ info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -201,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = await self.room_member_handler.update_membership(
+ event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -209,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
+ event_id = event.event_id
ret = {} # type: dict
- if event:
- set_tag("event_id", event.event_id)
- ret = {"event_id": event.event_id}
+ if event_id:
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -246,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -364,10 +369,13 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = None
handler = self.hs.get_room_list_handler()
- if server:
- data = await handler.get_remote_public_room_list(
- server, limit=limit, since_token=since_token
- )
+ if server and server != self.hs.config.server_name:
+ try:
+ data = await handler.get_remote_public_room_list(
+ server, limit=limit, since_token=since_token
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
else:
data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token
@@ -404,15 +412,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = None
handler = self.hs.get_room_list_handler()
- if server:
- data = await handler.get_remote_public_room_list(
- server,
- limit=limit,
- since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
+ if server and server != self.hs.config.server_name:
+ try:
+ data = await handler.get_remote_public_room_list(
+ server,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
else:
data = await handler.get_local_public_room_list(
limit=limit,
@@ -775,7 +786,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 7d2cd29a60..8d081718e3 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -34,7 +34,7 @@ from synapse.http.servlet import (
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
-from synapse.util.stringutils import assert_valid_client_secret
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -104,6 +104,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
if existing_user_id is None:
+ if self.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
@@ -219,6 +224,7 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
+ self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
self.http_client = hs.get_simple_http_client()
@@ -226,6 +232,20 @@ class PasswordRestServlet(RestServlet):
async def on_POST(self, request):
body = parse_json_object_from_request(request)
+ # we do basic sanity checks here because the auth layer will store these
+ # in sessions. Pull out the new password provided to us.
+ if "new_password" in body:
+ new_password = body.pop("new_password")
+ if not isinstance(new_password, str) or len(new_password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(new_password)
+
+ # If the password is valid, hash it and store it back on the body.
+ # This ensures that only the hashed password is handled everywhere.
+ if "new_password_hash" in body:
+ raise SynapseError(400, "Unexpected property: new_password_hash")
+ body["new_password_hash"] = await self.auth_handler.hash(new_password)
+
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -243,13 +263,21 @@ class PasswordRestServlet(RestServlet):
params = body
else:
params = await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
)
user_id = requester.user.to_string()
else:
requester = None
result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
+ [[LoginType.EMAIL_IDENTITY]],
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
)
if LoginType.EMAIL_IDENTITY in result:
@@ -272,12 +300,12 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password"])
- new_password = params["new_password"]
+ assert_params_in_dict(params, ["new_password_hash"])
+ new_password_hash = params["new_password_hash"]
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password, logout_devices, requester
+ user_id, new_password_hash, logout_devices, requester
)
if self.hs.config.shadow_server:
@@ -335,7 +363,11 @@ class DeactivateAccountRestServlet(RestServlet):
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
@@ -407,6 +439,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
+ if self.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
@@ -472,6 +509,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
+ if self.hs.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
if not self.hs.config.account_threepid_delegate_msisdn:
@@ -634,8 +676,10 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request):
- if self.hs.config.disable_3pid_changes:
- raise SynapseError(400, "3PID changes disabled on this server")
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -729,6 +773,11 @@ class ThreepidAddRestServlet(RestServlet):
@interactive_auth_handler
async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -739,7 +788,11 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
@@ -852,8 +905,10 @@ class ThreepidDeleteRestServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
- if self.hs.config.disable_3pid_changes:
- raise SynapseError(400, "3PID changes disabled on this server")
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 50e080673b..75590ebaeb 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -130,7 +130,22 @@ class AuthRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- def on_GET(self, request, stagetype):
+ # SSO configuration.
+ self._cas_enabled = hs.config.cas_enabled
+ if self._cas_enabled:
+ self._cas_handler = hs.get_cas_handler()
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
+ self._saml_enabled = hs.config.saml2_enabled
+ if self._saml_enabled:
+ self._saml_handler = hs.get_saml_handler()
+ self._oidc_enabled = hs.config.oidc_enabled
+ if self._oidc_enabled:
+ self._oidc_handler = hs.get_oidc_handler()
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
+
+ async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -142,14 +157,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
"session": session,
@@ -158,17 +165,50 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
+
+ elif stagetype == LoginType.SSO:
+ # Display a confirmation page which prompts the user to
+ # re-authenticate with their SSO provider.
+ if self._cas_enabled:
+ # Generate a request to CAS that redirects back to an endpoint
+ # to verify the successful authentication.
+ sso_redirect_url = self._cas_handler.get_redirect_url(
+ {"session": session},
+ )
+
+ elif self._saml_enabled:
+ # Some SAML identity providers (e.g. Google) require a
+ # RelayState parameter on requests. It is not necessary here, so
+ # pass in a dummy redirect URL (which will never get used).
+ client_redirect_url = b"unused"
+ sso_redirect_url = self._saml_handler.handle_redirect_request(
+ client_redirect_url, session
+ )
+
+ elif self._oidc_enabled:
+ client_redirect_url = b""
+ sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url, session
+ )
+
+ else:
+ raise SynapseError(400, "Homeserver not configured for SSO.")
+
+ html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+
else:
raise SynapseError(404, "Unknown auth stage type")
+ # Render the HTML and return.
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+ return None
+
async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
@@ -196,15 +236,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
-
- return None
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -225,17 +256,22 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
+ elif stagetype == LoginType.SSO:
+ # The SSO fallback workflow should not post here,
+ raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
else:
raise SynapseError(404, "Unknown auth stage type")
+ # Render the HTML and return.
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+ return None
+
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 94ff73f384..c0714fcfb1 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -81,7 +81,11 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@@ -127,7 +131,11 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index f7ed4daf90..8f41a3edbf 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -263,7 +263,11 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c3c96a9e86..183f9cf5c0 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -51,7 +51,7 @@ from synapse.http.servlet import (
from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
-from synapse.util.stringutils import assert_valid_client_secret
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -137,6 +137,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
+ if self.hs.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
@@ -206,6 +211,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
+ if self.hs.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
@@ -419,15 +429,19 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
- desired_password = None
+ desired_password_hash = None
if "password" in body:
- if (
- not isinstance(body["password"], string_types)
- or len(body["password"]) > 512
- ):
+ password = body.pop("password")
+ if not isinstance(password, string_types) or len(password) > 512:
raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(body["password"])
- desired_password = body["password"]
+ self.password_policy_handler.validate_password(password)
+
+ # If the password is valid, hash it and store it back on the body.
+ # This ensures that only the hashed password is handled everywhere.
+ if "password_hash" in body:
+ raise SynapseError(400, "Unexpected property: password_hash")
+ body["password_hash"] = await self.auth_handler.hash(password)
+ desired_password_hash = body["password_hash"]
desired_username = None
if "username" in body:
@@ -464,7 +478,7 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, string_types):
result = await self._do_appservice_registration(
desired_username,
- desired_password,
+ desired_password_hash,
desired_display_name,
access_token,
body,
@@ -486,7 +500,7 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password" not in body:
+ if "initial_device_display_name" in body and "password_hash" not in body:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -501,7 +515,7 @@ class RegisterRestServlet(RestServlet):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
- registered_user_id = self.auth_handler.get_session_data(
+ registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
@@ -513,7 +527,11 @@ class RegisterRestServlet(RestServlet):
)
auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows, body, self.hs.get_ip_from_request(request)
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
)
# Check that we're not trying to register a denied 3pid.
@@ -618,7 +636,7 @@ class RegisterRestServlet(RestServlet):
registered = False
else:
# NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password"])
+ assert_params_in_dict(params, ["password_hash"])
if not self.hs.config.register_mxid_from_3pid:
desired_username = params.get("username", None)
@@ -627,9 +645,7 @@ class RegisterRestServlet(RestServlet):
pass
guest_access_token = params.get("guest_access_token", None)
-
- # XXX: don't we need to validate these for length etc like we did on
- # the ones from the JSON body earlier on in the method?
+ new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -662,7 +678,7 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password=params.get("password", None),
+ password_hash=new_password_hash,
guest_access_token=guest_access_token,
default_display_name=desired_display_name,
threepid=threepid,
@@ -686,7 +702,7 @@ class RegisterRestServlet(RestServlet):
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
- self.auth_handler.set_session_data(
+ await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -709,12 +725,12 @@ class RegisterRestServlet(RestServlet):
return 200, {}
async def _do_appservice_registration(
- self, username, password, display_name, as_token, body
+ self, username, password_hash, display_name, as_token, body
):
# FIXME: appservice_register() is horribly duplicated with register()
# and they should probably just be combined together with a config flag.
user_id = await self.registration_handler.appservice_register(
- username, as_token, password, display_name
+ username, as_token, password_hash, display_name
)
result = await self._create_registration_details(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 63f07b63da..89002ffbff 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 38952a1d27..59529707df 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
"""
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- version = parse_string(request, "version")
+ version = parse_string(request, "version", required=True)
room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 503f2bed98..3689777266 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,6 @@
import logging
import os
-from six import PY3
from six.moves import urllib
from twisted.internet import defer
@@ -324,23 +323,15 @@ def get_filename_from_headers(headers):
upload_name_utf8 = upload_name_utf8[7:]
# We have a filename*= section. This MUST be ASCII, and any UTF-8
# bytes are %-quoted.
- if PY3:
- try:
- # Once it is decoded, we can then unquote the %-encoded
- # parts strictly into a unicode string.
- upload_name = urllib.parse.unquote(
- upload_name_utf8.decode("ascii"), errors="strict"
- )
- except UnicodeDecodeError:
- # Incorrect UTF-8.
- pass
- else:
- # On Python 2, we first unquote the %-encoded parts and then
- # decode it strictly using UTF-8.
- try:
- upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
- except UnicodeDecodeError:
- pass
+ try:
+ # Once it is decoded, we can then unquote the %-encoded
+ # parts strictly into a unicode string.
+ upload_name = urllib.parse.unquote(
+ upload_name_utf8.decode("ascii"), errors="strict"
+ )
+ except UnicodeDecodeError:
+ # Incorrect UTF-8.
+ pass
# If there isn't check for an ascii name.
if not upload_name:
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 66a01559e1..24d3ae5bbc 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
b" media-src 'self';"
b" object-src 'self';",
)
+ request.setHeader(
+ b"Referrer-Policy", b"no-referrer",
+ )
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 490b1b45a8..fd10d42f2f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -24,7 +24,6 @@ from six import iteritems
import twisted.internet.error
import twisted.web.http
-from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -114,15 +113,14 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
- @defer.inlineCallbacks
- def _update_recently_accessed(self):
+ async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
- yield self.store.update_cached_last_access_time(
+ await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@@ -138,8 +136,7 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
- @defer.inlineCallbacks
- def create_content(
+ async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@@ -158,11 +155,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
- fname = yield self.media_storage.store_file(content, file_info)
+ fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
- yield self.store.store_local_media(
+ await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -171,12 +168,11 @@ class MediaRepository(object):
user_id=auth_user,
)
- yield self._generate_thumbnails(None, media_id, media_id, media_type)
+ await self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
- @defer.inlineCallbacks
- def get_local_media(self, request, media_id, name):
+ async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@@ -190,7 +186,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -204,13 +200,12 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
- @defer.inlineCallbacks
- def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@@ -236,8 +231,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (yield self.remote_media_linearizer.queue(key)):
- responder, media_info = yield self._get_remote_media_impl(
+ with (await self.remote_media_linearizer.queue(key)):
+ responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -246,14 +241,13 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
- yield respond_with_responder(
+ await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
- @defer.inlineCallbacks
- def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -274,8 +268,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (yield self.remote_media_linearizer.queue(key)):
- responder, media_info = yield self._get_remote_media_impl(
+ with (await self.remote_media_linearizer.queue(key)):
+ responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -286,8 +280,7 @@ class MediaRepository(object):
return media_info
- @defer.inlineCallbacks
- def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -299,7 +292,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
- media_info = yield self.store.get_cached_remote_media(server_name, media_id)
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
@@ -317,19 +310,18 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, lets download it.
- media_info = yield self._download_remote_file(server_name, media_id, file_id)
+ media_info = await self._download_remote_file(server_name, media_id, file_id)
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- @defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -351,7 +343,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
- length, headers = yield self.client.get_file(
+ length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
@@ -397,7 +389,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
- yield finish()
+ await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@@ -405,7 +397,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
- yield self.store.store_cached_remote_media(
+ await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@@ -423,7 +415,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
+ await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@@ -458,16 +450,15 @@ class MediaRepository(object):
return t_byte_source
- @defer.inlineCallbacks
- def generate_local_exact_thumbnail(
+ async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -490,7 +481,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -500,22 +491,21 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
- yield self.store.store_local_thumbnail(
+ await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
- @defer.inlineCallbacks
- def generate_remote_exact_thumbnail(
+ async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -537,7 +527,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -547,7 +537,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
- yield self.store.store_remote_media_thumbnail(
+ await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -560,8 +550,7 @@ class MediaRepository(object):
return output_path
- @defer.inlineCallbacks
- def _generate_thumbnails(
+ async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@@ -582,7 +571,7 @@ class MediaRepository(object):
if not requirements:
return
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@@ -600,7 +589,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
- m_width, m_height = yield defer_to_thread(
+ m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@@ -620,11 +609,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@@ -646,7 +635,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -656,7 +645,7 @@ class MediaRepository(object):
# Write to database
if server_name:
- yield self.store.store_remote_media_thumbnail(
+ await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -667,15 +656,14 @@ class MediaRepository(object):
t_len,
)
else:
- yield self.store.store_local_thumbnail(
+ await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
- @defer.inlineCallbacks
- def delete_old_remote_media(self, before_ts):
- old_media = yield self.store.get_remote_media_before(before_ts)
+ async def delete_old_remote_media(self, before_ts):
+ old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -689,7 +677,7 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
- with (yield self.remote_media_linearizer.queue(key)):
+ with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
@@ -705,7 +693,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
- yield self.store.delete_remote_media(origin, media_id)
+ await self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 07e395cfd1..f206605727 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -86,6 +86,7 @@ class PreviewUrlResource(DirectServeResource):
self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
+ self.url_preview_accept_language = hs.config.url_preview_accept_language
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
@@ -165,8 +166,7 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
- @defer.inlineCallbacks
- def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@@ -179,7 +179,7 @@ class PreviewUrlResource(DirectServeResource):
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
- cache_result = yield self.store.get_url_cache(url, ts)
+ cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@@ -192,13 +192,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
- media_info = yield self._download_url(url, user)
+ media_info = await self._download_url(url, user)
logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
- dims = yield self.media_repo._generate_thumbnails(
+ dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@@ -248,14 +248,14 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
- image_info = yield self._download_url(
+ image_info = await self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
- dims = yield self.media_repo._generate_thumbnails(
+ dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
@@ -293,7 +293,7 @@ class PreviewUrlResource(DirectServeResource):
jsonog = json.dumps(og)
# store OG in history-aware DB cache
- yield self.store.store_url_cache(
+ await self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@@ -305,8 +305,7 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
- @defer.inlineCallbacks
- def _download_url(self, url, user):
+ async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -317,9 +316,12 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- logger.debug("Trying to get url '%s'", url)
- length, headers, uri, code = yield self.client.get_file(
- url, output_stream=f, max_size=self.max_spider_size
+ logger.debug("Trying to get preview for url '%s'", url)
+ length, headers, uri, code = await self.client.get_file(
+ url,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
)
except SynapseError:
# Pass SynapseErrors through directly, so that the servlet
@@ -345,7 +347,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
- yield finish()
+ await finish()
try:
if b"Content-Type" in headers:
@@ -356,7 +358,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
- yield self.store.store_local_media(
+ await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -393,22 +395,21 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
- @defer.inlineCallbacks
- def _expire_url_cache_data(self):
+ async def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
now = self.clock.time_msec()
- logger.info("Running url preview cache expiry")
+ logger.debug("Running url preview cache expiry")
- if not (yield self.store.db.updates.has_completed_background_updates()):
+ if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
- media_ids = yield self.store.get_expired_url_cache(now)
+ media_ids = await self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@@ -430,17 +431,19 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
- yield self.store.delete_url_cache(removed_media)
+ await self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
+ else:
+ logger.debug("No entries removed from url cache")
# Now we delete old images associated with the url cache.
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
- media_ids = yield self.store.get_url_cache_media_before(expire_before)
+ media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@@ -478,9 +481,12 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
- yield self.store.delete_url_cache_media(removed_media)
+ await self.store.delete_url_cache_media(removed_media)
- logger.info("Deleted %d media from url cache", len(removed_media))
+ if removed_media:
+ logger.info("Deleted %d media from url cache", len(removed_media))
+ else:
+ logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None):
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d57480f761..0b87220234 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
)
self.media_repo.mark_recently_accessed(server_name, media_id)
- @defer.inlineCallbacks
- def _respond_local_thumbnail(
+ async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
- thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(request, responder, t_type, t_length)
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
- @defer.inlineCallbacks
- def _select_or_generate_local_thumbnail(
+ async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
- thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
- yield respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
- file_path = yield self.media_repo.generate_local_exact_thumbnail(
+ file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
- yield respond_with_file(request, desired_type, file_path)
+ await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
- @defer.inlineCallbacks
- def _select_or_generate_remote_thumbnail(
+ async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+ media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
- thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
- yield respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
- file_path = yield self.media_repo.generate_remote_exact_thumbnail(
+ file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
- yield respond_with_file(request, desired_type, file_path)
+ await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
- @defer.inlineCallbacks
- def _respond_remote_thumbnail(
+ async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+ media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
- thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(request, responder, t_type, t_length)
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py
new file mode 100644
index 0000000000..d958dd65bb
--- /dev/null
+++ b/synapse/rest/oidc/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCResource(Resource):
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
new file mode 100644
index 0000000000..c03194f001
--- /dev/null
+++ b/synapse/rest/oidc/callback_resource.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.http.server import DirectServeResource, wrap_html_request_handler
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCCallbackResource(DirectServeResource):
+ isLeaf = 1
+
+ def __init__(self, hs):
+ super().__init__()
+ self._oidc_handler = hs.get_oidc_handler()
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self._oidc_handler.handle_oidc_callback(request)
|