diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 2e81eeff65..5d9cdf4bde 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.rest.admin
from synapse.http.server import JsonResource
+from synapse.rest import admin
from synapse.rest.client import versions
from synapse.rest.client.v1 import (
directory,
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
+ shared_rooms,
sync,
tags,
thirdparty,
@@ -123,6 +124,7 @@ class ClientRestResource(JsonResource):
password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
- synapse.rest.admin.register_servlets_for_client_rest_resource(
- hs, client_resource
- )
+ admin.register_servlets_for_client_rest_resource(hs, client_resource)
+
+ # unstable
+ shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 1c88c93f38..57cac22252 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -16,13 +16,13 @@
import logging
import platform
-import re
import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
historical_admin_path_patterns,
)
@@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
DeviceRestServlet,
DevicesRestServlet,
)
+from synapse.rest.admin.event_reports import EventReportsRestServlet
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -49,6 +50,7 @@ from synapse.rest.admin.users import (
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
+ UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
UsersRestServlet,
@@ -61,7 +63,7 @@ logger = logging.getLogger(__name__)
class VersionServlet(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
+ PATTERNS = admin_patterns("/server_version$")
def __init__(self, hs):
self.res = {
@@ -107,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- token = await self.store.get_topological_token_for_event(event_id)
+ room_token = await self.store.get_topological_token_for_event(event_id)
+ token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:
@@ -209,11 +212,13 @@ def register_servlets(hs, http_server):
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
+ UserMembershipRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
+ EventReportsRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d82eaf5e38..db9fea263a 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex):
]
-def admin_patterns(path_regex: str):
+def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
Args:
@@ -54,7 +54,7 @@ def admin_patterns(path_regex: str):
Returns:
A list of regex patterns.
"""
- admin_prefix = "^/_synapse/admin/v1"
+ admin_prefix = "^/_synapse/admin/" + version
patterns = [re.compile(admin_prefix + path_regex)]
return patterns
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 8d32677339..a163863322 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import re
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -21,7 +20,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
-from synapse.rest.admin._base import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -32,14 +31,12 @@ class DeviceRestServlet(RestServlet):
Get, update or delete the given user's device
"""
- PATTERNS = (
- re.compile(
- "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
- ),
+ PATTERNS = admin_patterns(
+ "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
def __init__(self, hs):
- super(DeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet):
Retrieve the given user's devices
"""
- PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs):
"""
@@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet):
key which lists the device_ids to delete.
"""
- PATTERNS = (
- re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
- )
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs):
self.hs = hs
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
new file mode 100644
index 0000000000..5b8d0594cd
--- /dev/null
+++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+logger = logging.getLogger(__name__)
+
+
+class EventReportsRestServlet(RestServlet):
+ """
+ List all reported events that are known to the homeserver. Results are returned
+ in a dictionary containing report information. Supports pagination.
+ The requester must have administrator access in Synapse.
+
+ GET /_synapse/admin/v1/event_reports
+ returns:
+ 200 OK with list of reports if success otherwise an error.
+
+ Args:
+ The parameters `from` and `limit` are required only for pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `dir` can be used to define the order of results.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `room_id` can be used to filter by room id.
+ Returns:
+ A list of reported events and an integer representing the total number of
+ reported events that exist given this query
+ """
+
+ PATTERNS = admin_patterns("/event_reports$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ direction = parse_string(request, "dir", default="b")
+ user_id = parse_string(request, "user_id")
+ room_id = parse_string(request, "room_id")
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "The start parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "The limit parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ event_reports, total = await self.store.get_event_reports_paginate(
+ start, limit, direction, user_id, room_id
+ )
+ ret = {"event_reports": event_reports, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(event_reports)
+
+ return 200, ret
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index f474066542..8b7bb6d44e 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,14 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
class PurgeRoomServlet(RestServlet):
@@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
+ PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b8c95d045a..09726d52d6 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -31,7 +31,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
historical_admin_path_patterns,
)
-from synapse.storage.data_stores.main.room import RoomSortOrder
+from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
logger = logging.getLogger(__name__)
@@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ purge = content.get("purge", True)
+ if not isinstance(purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet):
)
# Purge room
- await self.pagination_handler.purge_room(room_id)
+ if purge:
+ await self.pagination_handler.purge_room(room_id)
return (200, ret)
@@ -307,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet):
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+ # update_membership with an action of "invite" can raise a
+ # ShadowBanError. This is not handled since it is assumed that
+ # an admin isn't going to call this API with a shadow-banned user.
await self.room_member_handler.update_membership(
requester=requester,
target=fake_requester.user,
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 6e9a874121..375d055445 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -22,6 +20,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import UserID
@@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet):
self.snm = hs.get_server_notices_manager()
def register(self, json_resource):
- PATTERN = "^/_synapse/admin/v1/send_server_notice"
+ PATTERN = "/send_server_notice"
json_resource.register_paths(
- "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
+ "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
)
json_resource.register_paths(
"PUT",
- (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
+ admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"),
self.on_PUT,
self.__class__.__name__,
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index cc0bdfa5c9..20dc1d0e05 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -15,7 +15,6 @@
import hashlib
import hmac
import logging
-import re
from http import HTTPStatus
from synapse.api.constants import UserTypes
@@ -29,6 +28,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
@@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet):
class UsersRestServletV2(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+ PATTERNS = admin_patterns("/users$", "v2")
"""Get request to list all local users.
This needs user to have administrator access in Synapse.
@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
The parameters `from` and `limit` are required only for pagination.
By default, a `limit` of 100 is used.
The parameter `user_id` can be used to filter by user id.
+ The parameter `name` can be used to filter by user id or display name.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
"""
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id", default=None)
+ name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
users, total = await self.store.get_users_paginate(
- start, limit, user_id, guests, deactivated
+ start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if len(users) >= limit:
@@ -103,7 +105,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@@ -640,7 +642,7 @@ class UserAdminServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs):
self.hs = hs
@@ -681,3 +683,29 @@ class UserAdminServlet(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
return 200, {}
+
+
+class UserMembershipRestServlet(RestServlet):
+ """
+ Get room list of an user.
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+
+ def __init__(self, hs):
+ self.is_mine = hs.is_mine
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only lookup local users")
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ raise NotFoundError("User not found")
+
+ ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
+ return 200, ret
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 6da71dc46f..7be5c0fb88 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
-class HttpTransactionCache(object):
+class HttpTransactionCache:
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 5934b1fe8b..faabeeb91c 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
try:
- service = await self.auth.get_appservice_by_req(request)
+ service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
@@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryListServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
)
def __init__(self, hs):
- super(ClientAppserviceDirectoryListServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 25effd0261..1ecb77aa26 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -30,9 +30,10 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
- super(EventStreamRestServlet, self).__init__()
+ super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii")
- pagin_config = PaginationConfig.from_request(request)
+ pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args:
try:
@@ -74,7 +75,7 @@ class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(EventRestServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 910b3b4eeb..91da0ee573 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -24,14 +24,15 @@ class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
- super(InitialSyncRestServlet, self).__init__()
+ super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
- pagination_config = PaginationConfig.from_request(request)
+ pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 379f668d6f..3d1693d7ac 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,11 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.appservice import ApplicationService
+from synapse.handlers.auth import (
+ convert_client_dict_legacy_fields_to_identifier,
+ login_id_phone_to_thirdparty,
+)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -28,56 +33,11 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
-from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
-def login_submission_legacy_convert(submission):
- """
- If the input login submission is an old style object
- (ie. with top-level user / medium / address) convert it
- to a typed object.
- """
- if "user" in submission:
- submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
- del submission["user"]
-
- if "medium" in submission and "address" in submission:
- submission["identifier"] = {
- "type": "m.id.thirdparty",
- "medium": submission["medium"],
- "address": submission["address"],
- }
- del submission["medium"]
- del submission["address"]
-
-
-def login_id_thirdparty_from_phone(identifier):
- """
- Convert a phone login identifier type to a generic threepid identifier
- Args:
- identifier(dict): Login identifier dict of type 'm.id.phone'
-
- Returns: Login identifier dict of type 'm.id.threepid'
- """
- if "country" not in identifier or (
- # The specification requires a "phone" field, while Synapse used to require a "number"
- # field. Accept both for backwards compatibility.
- "phone" not in identifier
- and "number" not in identifier
- ):
- raise SynapseError(400, "Invalid phone-type identifier")
-
- # Accept both "phone" and "number" as valid keys in m.id.phone
- phone_number = identifier.get("phone", identifier["number"])
-
- msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
-
- return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
-
-
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -85,9 +45,10 @@ class LoginRestServlet(RestServlet):
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
+ APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
def __init__(self, hs):
- super(LoginRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
# JWT configuration variables.
@@ -102,6 +63,8 @@ class LoginRestServlet(RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self.auth = hs.get_auth()
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -157,8 +120,12 @@ class LoginRestServlet(RestServlet):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
+
try:
- if self.jwt_enabled and (
+ if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
+ appservice = self.auth.get_appservice_by_req(request)
+ result = await self._do_appservice_login(login_submission, appservice)
+ elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
@@ -175,6 +142,33 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
+ def _get_qualified_user_id(self, identifier):
+ if identifier["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+ if "user" not in identifier:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if identifier["user"].startswith("@"):
+ return identifier["user"]
+ else:
+ return UserID(identifier["user"], self.hs.hostname).to_string()
+
+ async def _do_appservice_login(
+ self, login_submission: JsonDict, appservice: ApplicationService
+ ):
+ logger.info(
+ "Got appservice login request with identifier: %r",
+ login_submission.get("identifier"),
+ )
+
+ identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
+ qualified_user_id = self._get_qualified_user_id(identifier)
+
+ if not appservice.is_interested_in_user(qualified_user_id):
+ raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
+
+ return await self._complete_login(qualified_user_id, login_submission)
+
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,18 +188,11 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
- login_submission_legacy_convert(login_submission)
-
- if "identifier" not in login_submission:
- raise SynapseError(400, "Missing param: identifier")
-
- identifier = login_submission["identifier"]
- if "type" not in identifier:
- raise SynapseError(400, "Login identifier has no type")
+ identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
- identifier = login_id_thirdparty_from_phone(identifier)
+ identifier = login_id_phone_to_thirdparty(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
@@ -267,15 +254,7 @@ class LoginRestServlet(RestServlet):
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- qualified_user_id = identifier["user"]
- else:
- qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+ qualified_user_id = self._get_qualified_user_id(identifier)
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
@@ -303,9 +282,7 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
- callback: Optional[
- Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
- ] = None,
+ callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
@@ -318,12 +295,12 @@ class LoginRestServlet(RestServlet):
Args:
user_id: ID of the user to register.
login_submission: Dictionary of login information.
- callback: Callback function to run after registration.
+ callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
Returns:
- result: Dictionary of account information after successful registration.
+ result: Dictionary of account information after successful login.
"""
# Before we actually log them in we check if they've already logged in
@@ -358,14 +335,24 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ """
+ Handle the final stage of SSO login.
+
+ Args:
+ login_submission: The JSON request body.
+
+ Returns:
+ The body of the JSON response.
+ """
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
- result = await self._complete_login(user_id, login_submission)
- return result
+ return await self._complete_login(
+ user_id, login_submission, self.auth_handler._sso_login_callback
+ )
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
@@ -448,7 +435,7 @@ class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
- super(CasTicketServlet, self).__init__()
+ super().__init__()
self._cas_handler = hs.get_cas_handler()
async def on_GET(self, request: SynapseRequest) -> None:
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b0c30b65be..f792b50cdc 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
- super(LogoutRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
- super(LogoutAllRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index ceaa28c212..4796cdac05 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
- super(PresenceStatusRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 165313b572..204b2ec9e5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -26,7 +26,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
- super(ProfileDisplaynameRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.http_client = hs.get_simple_http_client()
@@ -91,7 +91,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
- super(ProfileAvatarURLRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.http_client = hs.get_simple_http_client()
@@ -159,7 +159,7 @@ class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
- super(ProfileRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9fd4908136..f9eecb7cf5 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.api.errors import (
NotFoundError,
StoreError,
@@ -25,7 +24,7 @@ from synapse.http.servlet import (
parse_json_value_from_request,
parse_string,
)
-from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -39,12 +38,14 @@ class PushRuleRestServlet(RestServlet):
)
def __init__(self, hs):
- super(PushRuleRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+
async def on_PUT(self, request, path):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -158,10 +159,22 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
def notify_user(self, user_id):
- stream_id, _ = self.store.get_push_rules_stream_token()
+ stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
- def set_rule_attr(self, user_id, spec, val):
+ async def set_rule_attr(self, user_id, spec, val):
+ if spec["attr"] not in ("enabled", "actions"):
+ # for the sake of potential future expansion, shouldn't report
+ # 404 in the case of an unknown request so check it corresponds to
+ # a known attribute first.
+ raise UnrecognizedRequestError()
+
+ namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+ rule_id = spec["rule_id"]
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -170,8 +183,9 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+ return await self.store.set_push_rule_enabled(
+ user_id, namespaced_rule_id, val, is_default_rule
+ )
elif spec["attr"] == "actions":
actions = val.get("actions")
_check_actions(actions)
@@ -179,9 +193,14 @@ class PushRuleRestServlet(RestServlet):
rule_id = spec["rule_id"]
is_default_rule = rule_id.startswith(".")
if is_default_rule:
- if namespaced_rule_id not in BASE_RULE_IDS:
+ if user_id in self._users_new_default_push_rules:
+ rule_ids = NEW_RULE_IDS
+ else:
+ rule_ids = BASE_RULE_IDS
+
+ if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- return self.store.set_push_rule_actions(
+ return await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 5f65cb7d83..28dabf1c7a 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
- super(PushersRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
- super(PushersSetRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(PushersRemoveRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 1a3398316d..b421fe855e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,14 +21,13 @@ import re
from typing import List, Optional
from urllib import parse as urlparse
-from canonicaljson import json
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
HttpResponseException,
InvalidClientCredentialsError,
+ ShadowBanError,
SynapseError,
)
from synapse.api.filtering import Filter
@@ -46,6 +45,8 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.util import json_decoder
+from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
@@ -56,7 +57,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
def __init__(self, hs):
- super(TransactionRestServlet, self).__init__()
+ super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -64,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
- super(RoomCreateRestServlet, self).__init__(hs)
+ super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
@@ -110,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomStateEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@@ -170,7 +171,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
room_id=room_id,
event_type=event_type,
state_key=state_key,
- is_guest=requester.is_guest,
)
if not data:
@@ -200,23 +200,26 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
- if event_type == EventTypes.Member:
- membership = content.get("membership", None)
- event_id, _ = await self.room_member_handler.update_membership(
- requester,
- target=UserID.from_string(state_key),
- room_id=room_id,
- action=membership,
- content=content,
- )
- else:
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, txn_id=txn_id
- )
- event_id = event.event_id
+ try:
+ if event_type == EventTypes.Member:
+ membership = content.get("membership", None)
+ event_id, _ = await self.room_member_handler.update_membership(
+ requester,
+ target=UserID.from_string(state_key),
+ room_id=room_id,
+ action=membership,
+ content=content,
+ )
+ else:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
set_tag("event_id", event_id)
ret = {"event_id": event_id}
@@ -226,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomSendEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -249,12 +252,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, txn_id=txn_id
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- set_tag("event_id", event.event_id)
- return 200, {"event_id": event.event_id}
+ set_tag("event_id", event_id)
+ return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
@@ -270,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
- super(JoinRoomAliasServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -333,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
- super(PublicRoomListRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@@ -438,13 +448,14 @@ class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
- super(RoomMemberListRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -455,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
if at_token_string is None:
at_token = None
else:
- at_token = StreamToken.from_string(at_token_string)
+ at_token = await StreamToken.from_string(self.store, at_token_string)
# let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter
@@ -489,7 +500,7 @@ class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
- super(JoinedRoomMemberListRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@@ -508,18 +519,23 @@ class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
- super(RoomMessageListRestServlet, self).__init__()
+ super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(request, default_limit=10)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@@ -545,7 +561,7 @@ class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
- super(RoomStateRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@@ -565,13 +581,14 @@ class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
- super(RoomInitialSyncRestServlet, self).__init__()
+ super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(request)
+ pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
@@ -584,7 +601,7 @@ class RoomEventServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomEventServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
@@ -616,7 +633,7 @@ class RoomEventContextServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomEventContextServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
@@ -631,7 +648,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
else:
event_filter = None
@@ -661,7 +680,7 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomForgetRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -687,7 +706,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomMembershipRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -716,17 +735,21 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- await self.room_member_handler.do_3pid_invite(
- room_id,
- requester.user,
- content["medium"],
- content["address"],
- content["id_server"],
- requester,
- txn_id,
- new_room=False,
- id_access_token=content.get("id_access_token"),
- )
+ try:
+ await self.room_member_handler.do_3pid_invite(
+ room_id,
+ requester.user,
+ content["medium"],
+ content["address"],
+ content["id_server"],
+ requester,
+ txn_id,
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
+ )
+ except ShadowBanError:
+ # Pretend the request succeeded.
+ pass
return 200, {}
target = requester.user
@@ -738,15 +761,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content:
event_content = {"reason": content["reason"]}
- await self.room_member_handler.update_membership(
- requester=requester,
- target=target,
- room_id=room_id,
- action=membership_action,
- txn_id=txn_id,
- third_party_signed=content.get("third_party_signed", None),
- content=event_content,
- )
+ try:
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=target,
+ room_id=room_id,
+ action=membership_action,
+ txn_id=txn_id,
+ third_party_signed=content.get("third_party_signed", None),
+ content=event_content,
+ )
+ except ShadowBanError:
+ # Pretend the request succeeded.
+ pass
return_value = {}
@@ -771,7 +798,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomRedactEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -784,20 +811,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Redaction,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "redacts": event_id,
- },
- txn_id=txn_id,
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "redacts": event_id,
+ },
+ txn_id=txn_id,
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- set_tag("event_id", event.event_id)
- return 200, {"event_id": event.event_id}
+ set_tag("event_id", event_id)
+ return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)
@@ -813,7 +847,7 @@ class RoomTypingRestServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomTypingRestServlet, self).__init__()
+ super().__init__()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@@ -840,17 +874,21 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
- if content["typing"]:
- await self.typing_handler.started_typing(
- target_user=target_user,
- auth_user=requester.user,
- room_id=room_id,
- timeout=timeout,
- )
- else:
- await self.typing_handler.stopped_typing(
- target_user=target_user, auth_user=requester.user, room_id=room_id
- )
+ try:
+ if content["typing"]:
+ await self.typing_handler.started_typing(
+ target_user=target_user,
+ requester=requester,
+ room_id=room_id,
+ timeout=timeout,
+ )
+ else:
+ await self.typing_handler.stopped_typing(
+ target_user=target_user, requester=requester, room_id=room_id
+ )
+ except ShadowBanError:
+ # Pretend this worked without error.
+ pass
return 200, {}
@@ -882,7 +920,7 @@ class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
- super(SearchRestServlet, self).__init__()
+ super().__init__()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -903,7 +941,7 @@ class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
- super(JoinedRoomsRestServlet, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 50277c6cf6..b8d491ca5c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
- super(VoipRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d4b1ee1e8c..1320aad8f6 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
from http import HTTPStatus
from typing import TYPE_CHECKING
@@ -23,10 +24,13 @@ from urllib.parse import urlparse
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
-from twisted.internet import defer
-
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ InteractiveAuthIncompleteError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
@@ -35,7 +39,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -50,28 +54,18 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
- super(EmailPasswordRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
self.config = hs.config
self.identity_handler = hs.get_handlers().identity_handler
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_password_reset_template_html,
- self.config.email_password_reset_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_password_reset_template_html,
+ template_text=self.config.email_password_reset_template_text,
)
async def on_POST(self, request):
@@ -104,13 +98,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
- raise SynapseError(
- 403,
- "Your email is not authorized on this server",
- Codes.THREEPID_DENIED,
- )
-
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -127,6 +114,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -158,87 +148,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret
-class PasswordResetSubmitTokenServlet(RestServlet):
- """Handles 3PID validation token submission"""
-
- PATTERNS = client_patterns(
- "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
- )
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- super(PasswordResetSubmitTokenServlet, self).__init__()
- self.hs = hs
- self.auth = hs.get_auth()
- self.config = hs.config
- self.clock = hs.get_clock()
- self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_password_reset_template_failure_html],
- )
-
- async def on_GET(self, request, medium):
- # We currently only handle threepid token submissions for email
- if medium != "email":
- raise SynapseError(
- 400, "This medium is currently not supported for password resets"
- )
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Password reset emails have been disabled due to lack of an email config"
- )
- raise SynapseError(
- 400, "Email-based password resets are disabled on this server"
- )
-
- sid = parse_string(request, "sid", required=True)
- token = parse_string(request, "token", required=True)
- client_secret = parse_string(request, "client_secret", required=True)
- assert_valid_client_secret(client_secret)
-
- # Attempt to validate a 3PID session
- try:
- # Mark the session as valid
- next_link = await self.store.validate_threepid_session(
- sid, client_secret, token, self.clock.time_msec()
- )
-
- # Perform a 302 redirect if next_link is set
- if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
-
- # Otherwise show the success template
- html = self.config.email_password_reset_template_success_html
- status_code = 200
- except ThreepidValidationError as e:
- status_code = e.code
-
- # Show a failure page with a reason
- template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
-
- respond_with_html(request, status_code, html)
-
-
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
def __init__(self, hs):
- super(PasswordRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -253,18 +167,12 @@ class PasswordRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- if "new_password" in body:
- new_password = body.pop("new_password")
+ new_password = body.pop("new_password", None)
+ if new_password is not None:
if not isinstance(new_password, str) or len(new_password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "new_password_hash" in body:
- raise SynapseError(400, "Unexpected property: new_password_hash")
- body["new_password_hash"] = await self.auth_handler.hash(new_password)
-
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -281,23 +189,52 @@ class PasswordRestServlet(RestServlet):
if requester.app_service:
params = body
else:
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
+ try:
+ (
+ params,
+ session_id,
+ ) = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
+ user_id = requester.user.to_string()
+ else:
+ requester = None
+ try:
+ result, params, session_id = await self.auth_handler.check_ui_auth(
+ [[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
- user_id = requester.user.to_string()
- else:
- requester = None
- result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]],
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
@@ -322,32 +259,40 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password_hash"])
- new_password_hash = params["new_password_hash"]
+ # If we have a password in this request, prefer it. Otherwise, there
+ # must be a password hash from an earlier request.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ else:
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password_hash, logout_devices, requester
+ user_id, password_hash, logout_devices, requester
)
if self.hs.config.shadow_server:
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
- self.shadow_password(params, shadow_user.to_string())
+ await self.shadow_password(params, shadow_user.to_string())
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
- @defer.inlineCallbacks
- def shadow_password(self, body, user_id):
+ async def shadow_password(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -358,7 +303,7 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs):
- super(DeactivateAccountRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -377,7 +322,7 @@ class DeactivateAccountRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
- # allow ASes to dectivate their own users
+ # allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
@@ -406,26 +351,18 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
- super(EmailThreepidRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.config = hs.config
self.identity_handler = hs.get_handlers().identity_handler
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_add_threepid_template_html,
- self.config.email_add_threepid_template_text,
- ],
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_add_threepid_template_html,
+ template_text=self.config.email_add_threepid_template_text,
)
async def on_POST(self, request):
@@ -473,6 +410,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -509,7 +449,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
- super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
@@ -545,6 +485,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -588,9 +531,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_add_threepid_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_add_threepid_template_failure_html
)
async def on_GET(self, request):
@@ -636,7 +578,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -687,7 +629,7 @@ class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs):
- super(ThreepidRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -731,7 +673,7 @@ class ThreepidRestServlet(RestServlet):
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
- self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
@@ -766,7 +708,7 @@ class ThreepidRestServlet(RestServlet):
"address": validation_session["address"],
"validated_at": validation_session["validated_at"],
}
- self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
@@ -774,13 +716,12 @@ class ThreepidRestServlet(RestServlet):
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
- @defer.inlineCallbacks
- def shadow_3pid(self, body, user_id):
+ async def shadow_3pid(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -791,7 +732,7 @@ class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
def __init__(self, hs):
- super(ThreepidAddRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -841,20 +782,19 @@ class ThreepidAddRestServlet(RestServlet):
"address": validation_session["address"],
"validated_at": validation_session["validated_at"],
}
- self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
- @defer.inlineCallbacks
- def shadow_3pid(self, body, user_id):
+ async def shadow_3pid(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -865,7 +805,7 @@ class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
def __init__(self, hs):
- super(ThreepidBindRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -894,7 +834,7 @@ class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")
def __init__(self, hs):
- super(ThreepidUnbindRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -925,7 +865,7 @@ class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs):
- super(ThreepidDeleteRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -958,7 +898,7 @@ class ThreepidDeleteRestServlet(RestServlet):
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
- self.shadow_3pid_delete(body, shadow_user.to_string())
+ await self.shadow_3pid_delete(body, shadow_user.to_string())
if ret:
id_server_unbind_result = "success"
@@ -967,13 +907,12 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
- @defer.inlineCallbacks
- def shadow_3pid_delete(self, body, user_id):
+ async def shadow_3pid_delete(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -988,12 +927,11 @@ class ThreepidLookupRestServlet(RestServlet):
self.auth = hs.get_auth()
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
"""Proxy a /_matrix/identity/api/v1/lookup request to an identity
server
"""
- yield self.auth.get_user_by_req(request)
+ await self.auth.get_user_by_req(request)
# Verify query parameters
query_params = request.args
@@ -1006,9 +944,9 @@ class ThreepidLookupRestServlet(RestServlet):
# Proxy the request to the identity server. lookup_3pid handles checking
# if the lookup is allowed so we don't need to do it here.
- ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+ ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
- defer.returnValue((200, ret))
+ return 200, ret
class ThreepidBulkLookupRestServlet(RestServlet):
@@ -1019,12 +957,11 @@ class ThreepidBulkLookupRestServlet(RestServlet):
self.auth = hs.get_auth()
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
server
"""
- yield self.auth.get_user_by_req(request)
+ await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
@@ -1032,11 +969,11 @@ class ThreepidBulkLookupRestServlet(RestServlet):
# Proxy the request to the identity server. lookup_3pid handles checking
# if the lookup is allowed so we don't need to do it here.
- ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+ ret = await self.identity_handler.proxy_bulk_lookup_3pid(
body["id_server"], body["threepids"]
)
- defer.returnValue((200, ret))
+ return 200, ret
def assert_valid_next_link(hs: "HomeServer", next_link: str):
@@ -1082,7 +1019,7 @@ class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs):
- super(WhoamiRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
async def on_GET(self, request):
@@ -1093,7 +1030,6 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index d31ec7c29d..617ee6d62a 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -35,7 +35,7 @@ class AccountDataServlet(RestServlet):
)
def __init__(self, hs):
- super(AccountDataServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -93,7 +93,7 @@ class RoomAccountDataServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomAccountDataServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d06336ceea..bd7f9ae203 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValidityRenewServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
@@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValiditySendMailServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e585e9153..5fbfae5991 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -25,94 +25,6 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
-RECAPTCHA_TEMPLATE = """
-<html>
-<head>
-<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<script src="https://www.recaptcha.net/recaptcha/api.js"
- async defer></script>
-<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-function captchaDone() {
- $('#registrationForm').submit();
-}
-</script>
-</head>
-<body>
-<form id="registrationForm" method="post" action="%(myurl)s">
- <div>
- <p>
- Hello! We need to prevent computer programs and other automated
- things from creating accounts on this server.
- </p>
- <p>
- Please verify that you're not a robot.
- </p>
- <input type="hidden" name="session" value="%(session)s" />
- <div class="g-recaptcha"
- data-sitekey="%(sitekey)s"
- data-callback="captchaDone">
- </div>
- <noscript>
- <input type="submit" value="All Done" />
- </noscript>
- </div>
- </div>
-</form>
-</body>
-</html>
-"""
-
-TERMS_TEMPLATE = """
-<html>
-<head>
-<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-</head>
-<body>
-<form id="registrationForm" method="post" action="%(myurl)s">
- <div>
- <p>
- Please click the button below if you agree to the
- <a href="%(terms_url)s">privacy policy of this homeserver.</a>
- </p>
- <input type="hidden" name="session" value="%(session)s" />
- <input type="submit" value="Agree" />
- </div>
-</form>
-</body>
-</html>
-"""
-
-SUCCESS_TEMPLATE = """
-<html>
-<head>
-<title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-if (window.onAuthDone) {
- window.onAuthDone();
-} else if (window.opener && window.opener.postMessage) {
- window.opener.postMessage("authDone", "*");
-}
-</script>
-</head>
-<body>
- <div>
- <p>Thank you</p>
- <p>You may now close this window and return to the application</p>
- </div>
-</body>
-</html>
-"""
-
class AuthRestServlet(RestServlet):
"""
@@ -124,7 +36,7 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
- super(AuthRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -145,26 +57,30 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
+ self.recaptcha_template = hs.config.recaptcha_template
+ self.terms_template = hs.config.terms_template
+ self.success_template = hs.config.fallback_success_template
+
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
if stagetype == LoginType.RECAPTCHA:
- html = RECAPTCHA_TEMPLATE % {
- "session": session,
- "myurl": "%s/r0/auth/%s/fallback/web"
+ html = self.recaptcha_template.render(
+ session=session,
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- "sitekey": self.hs.config.recaptcha_public_key,
- }
+ sitekey=self.hs.config.recaptcha_public_key,
+ )
elif stagetype == LoginType.TERMS:
- html = TERMS_TEMPLATE % {
- "session": session,
- "terms_url": "%s_matrix/consent?v=%s"
+ html = self.terms_template.render(
+ session=session,
+ terms_url="%s_matrix/consent?v=%s"
% (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
- "myurl": "%s/r0/auth/%s/fallback/web"
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
- }
+ )
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
@@ -222,14 +138,14 @@ class AuthRestServlet(RestServlet):
)
if success:
- html = SUCCESS_TEMPLATE
+ html = self.success_template.render()
else:
- html = RECAPTCHA_TEMPLATE % {
- "session": session,
- "myurl": "%s/r0/auth/%s/fallback/web"
+ html = self.recaptcha_template.render(
+ session=session,
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- "sitekey": self.hs.config.recaptcha_public_key,
- }
+ sitekey=self.hs.config.recaptcha_public_key,
+ )
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -238,18 +154,18 @@ class AuthRestServlet(RestServlet):
)
if success:
- html = SUCCESS_TEMPLATE
+ html = self.success_template.render()
else:
- html = TERMS_TEMPLATE % {
- "session": session,
- "terms_url": "%s_matrix/consent?v=%s"
+ html = self.terms_template.render(
+ session=session,
+ terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- "myurl": "%s/r0/auth/%s/fallback/web"
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
- }
+ )
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fe9d019c44..76879ac559 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(CapabilitiesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index c0714fcfb1..7e174de692 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(DevicesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
- super(DeleteDevicesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(DeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b28da017cd..7cc692643b 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
- super(GetFilterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
- super(CreateFilterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d84a6d7e11..a3bb095c2d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -31,7 +32,7 @@ class GroupServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
- super(GroupServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -65,7 +66,7 @@ class GroupSummaryServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
- super(GroupSummaryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -96,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupSummaryRoomsCatServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -136,7 +137,7 @@ class GroupCategoryServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupCategoryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -180,7 +181,7 @@ class GroupCategoriesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
def __init__(self, hs):
- super(GroupCategoriesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -203,7 +204,7 @@ class GroupRoleServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
def __init__(self, hs):
- super(GroupRoleServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -247,7 +248,7 @@ class GroupRolesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
def __init__(self, hs):
- super(GroupRolesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -278,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupSummaryUsersRoleServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -316,7 +317,7 @@ class GroupRoomServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
- super(GroupRoomServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -339,7 +343,7 @@ class GroupUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
- super(GroupUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -362,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
- super(GroupInvitedUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -385,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
- super(GroupSettingJoinPolicyServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@@ -409,7 +413,7 @@ class GroupCreateServlet(RestServlet):
PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
- super(GroupCreateServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -440,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminRoomsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -477,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminRoomsConfigServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -503,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminUsersInviteServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -532,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminUsersKickServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -556,7 +560,7 @@ class GroupSelfLeaveServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
def __init__(self, hs):
- super(GroupSelfLeaveServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -580,7 +584,7 @@ class GroupSelfJoinServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
def __init__(self, hs):
- super(GroupSelfJoinServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -604,7 +608,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
def __init__(self, hs):
- super(GroupSelfAcceptInviteServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -628,7 +632,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
def __init__(self, hs):
- super(GroupSelfUpdatePublicityServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -651,7 +655,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
def __init__(self, hs):
- super(PublicisedGroupsForUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -672,7 +676,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups$")
def __init__(self, hs):
- super(PublicisedGroupsForUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -696,7 +700,7 @@ class GroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/joined_groups$")
def __init__(self, hs):
- super(GroupsForUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24bb090822..55c4606569 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(KeyUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- super(KeyQueryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -177,9 +177,10 @@ class KeyChangesServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- super(KeyChangesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet):
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
- from_token = StreamToken.from_string(from_token_string)
+ from_token = await StreamToken.from_string(self.store, from_token_string)
user_id = requester.user.to_string()
@@ -222,7 +223,7 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
- super(OneTimeKeyServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -250,7 +251,7 @@ class SigningKeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SigningKeyUploadServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -308,7 +309,7 @@ class SignaturesUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SignaturesUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index aa911d75ee..87063ec8b1 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
- super(NotificationsServlet, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 6ae9a5a8e9..5b996e2d63 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
def __init__(self, hs):
- super(IdTokenServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
index 968403cca4..68b27ff23a 100644
--- a/synapse/rest/client/v2_alpha/password_policy.py
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(PasswordPolicyServlet, self).__init__()
+ super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 67cbc37312..55c6688f52 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
- super(ReadMarkerRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 92555bd4a9..6f7246a394 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet):
)
def __init__(self, hs):
- super(ReceiptRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 001f49fb3e..91ea76bc20 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,6 +17,7 @@
import hmac
import logging
+import random
import re
from typing import List, Union
@@ -26,6 +27,7 @@ import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
+ InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -45,7 +47,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -76,29 +78,17 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(EmailRegisterRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- from synapse.push.mailer import Mailer, load_jinja2_templates
-
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_registration_template_html,
- self.config.email_registration_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_registration_template_html,
+ template_text=self.config.email_registration_template_text,
)
async def on_POST(self, request):
@@ -144,6 +134,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -183,7 +176,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -218,6 +211,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
@@ -257,7 +253,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegistrationSubmitTokenServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.config = hs.config
@@ -265,15 +261,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
- )
-
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_registration_template_failure_html
)
async def on_GET(self, request, medium):
@@ -321,7 +310,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -334,7 +323,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(UsernameAvailabilityRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
@@ -372,7 +361,7 @@ class RegisterRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegisterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -385,6 +374,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_enabled = self.hs.config.enable_registration
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -410,22 +400,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # we do basic sanity checks here because the auth layer will store these
- # in sessions. Pull out the username/password provided to us.
- desired_password_hash = None
- if "password" in body:
- password = body.pop("password")
- if not isinstance(password, str) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(password)
-
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "password_hash" in body:
- raise SynapseError(400, "Unexpected property: password_hash")
- body["password_hash"] = await self.auth_handler.hash(password)
- desired_password_hash = body["password_hash"]
-
# We don't care about usernames for this deployment. In fact, the act
# of checking whether they exist already can leak metadata about
# which users are already registered.
@@ -440,7 +414,12 @@ class RegisterRestServlet(RestServlet):
appservice = None
if self.auth.has_access_token(request):
- appservice = await self.auth.get_appservice_by_req(request)
+ appservice = self.auth.get_appservice_by_req(request)
+
+ # We need to retrieve the password early in order to pass it to
+ # application service registration
+ # This is specific to shadow server registration of users via an AS
+ password = body.pop("password", None)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -459,23 +438,33 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
- if isinstance(desired_username, str):
- result = await self._do_appservice_registration(
- desired_username,
- desired_password_hash,
- desired_display_name,
- access_token,
- body,
- )
- return 200, result # we throw for non 200 responses
+ if not isinstance(desired_username, str):
+ raise SynapseError(400, "Desired Username is missing or not a string")
+
+ result = await self._do_appservice_registration(
+ desired_username, password, desired_display_name, access_token, body
+ )
+
+ return 200, result
# == Normal User Registration == (everyone else)
- if not self.hs.config.enable_registration:
+ if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
+ # Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password_hash" not in body:
+ # Pull out the provided password and do basic sanity checks early.
+ #
+ # Note that we remove the password from the body since the auth layer
+ # will store the body in the session and we don't want a plaintext
+ # password store there.
+ if password is not None:
+ if not isinstance(password, str) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(password)
+
+ if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -485,6 +474,7 @@ class RegisterRestServlet(RestServlet):
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
+ password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
@@ -493,21 +483,43 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
+ # Extract the previously-hashed password from the session.
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
- auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
- )
+ # Check if the user-interactive authentication flows are complete, if
+ # not this will raise a user-interactive auth error.
+ try:
+ auth_result, params, session_id = await self.auth_handler.check_ui_auth(
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth.
+ #
+ # Hash the password and store it with the session since the client
+ # is not required to provide the password again.
+ #
+ # If a password hash was previously stored we will not attempt to
+ # re-hash and store it for efficiency. This assumes the password
+ # does not change throughout the authentication flow, but this
+ # should be fine since the data is meant to be consistent.
+ if not password_hash and password:
+ password_hash = await self.auth_handler.hash(password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
-
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
@@ -603,8 +615,12 @@ class RegisterRestServlet(RestServlet):
# don't re-register the threepids
registered = False
else:
- # NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password_hash"])
+ # If we have a password in this request, prefer it. Otherwise, there
+ # might be a password hash from an earlier request.
+ if password:
+ password_hash = await self.auth_handler.hash(password)
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
if not self.hs.config.register_mxid_from_3pid:
desired_username = params.get("username", None)
@@ -613,7 +629,6 @@ class RegisterRestServlet(RestServlet):
pass
guest_access_token = params.get("guest_access_token", None)
- new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -653,13 +668,18 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
+ entries = await self.store.get_user_agents_ips_to_ui_auth_session(
+ session_id
+ )
+
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password_hash=new_password_hash,
+ password_hash=password_hash,
guest_access_token=guest_access_token,
default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
+ user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
@@ -677,8 +697,8 @@ class RegisterRestServlet(RestServlet):
params=params,
)
- # remember that we've now registered that user account, and with
- # what user ID (since the user may not have specified)
+ # Remember that the user account has been registered (and the user
+ # ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -702,12 +722,20 @@ class RegisterRestServlet(RestServlet):
return 200, {}
async def _do_appservice_registration(
- self, username, password_hash, display_name, as_token, body
+ self, username, password, display_name, as_token, body
):
# FIXME: appservice_register() is horribly duplicated with register()
# and they should probably just be combined together with a config flag.
+
+ if password:
+ # Hash the password
+ #
+ # In mainline hashing of the password was moved further on in the registration
+ # flow, but we need it here for the AS use-case of shadow servers
+ password = await self.auth_handler.hash(password)
+
user_id = await self.registration_handler.appservice_register(
- username, as_token, password_hash, display_name
+ username, as_token, password, display_name
)
result = await self._create_registration_details(user_id, body)
@@ -736,7 +764,7 @@ class RegisterRestServlet(RestServlet):
(object) params: registration parameters, from which we pull
device_id, initial_device_name and inhibit_login
Returns:
- defer.Deferred: (object) dictionary for response from /register
+ dictionary for response from /register
"""
result = {"user_id": user_id, "home_server": self.hs.hostname}
if not params.get("inhibit_login", False):
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 89002ffbff..18c75738f8 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
import logging
from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.util.stringutils import random_string
from ._base import client_patterns
@@ -60,7 +61,7 @@ class RelationSendServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationSendServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict=event_dict, txn_id=txn_id
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict=event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- return 200, {"event_id": event.event_id}
+ return 200, {"event_id": event_id}
class RelationPaginationServlet(RestServlet):
@@ -130,7 +138,7 @@ class RelationPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -225,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationAggregationPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
@@ -303,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationAggregationGroupPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e15927c4ea..215d619ca1 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
def __init__(self, hs):
- super(ReportEventRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 59529707df..53de97923f 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysNewVersionServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysVersionServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index f357015a70..bf030e0ff4 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,13 +15,14 @@
import logging
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.util import stringutils
from ._base import client_patterns
@@ -52,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomUpgradeRestServlet, self).__init__()
+ super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
- new_version = content["new_version"]
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
if new_version is None:
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION,
)
- new_room_id = await self._room_creation_handler.upgrade_room(
- requester, room_id, new_version
- )
+ try:
+ new_room_id = await self._room_creation_handler.upgrade_room(
+ requester, room_id, new_version
+ )
+ except ShadowBanError:
+ # Generate a random room ID.
+ new_room_id = stringutils.random_string(18)
ret = {"replacement_room": new_room_id}
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index db829f3098..bc4f43639a 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SendToDeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..c866d5151c
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+ """
+ GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+ """
+
+ PATTERNS = client_patterns(
+ "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+ releases=(), # This is an unstable feature
+ )
+
+ def __init__(self, hs):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.user_directory_active = hs.config.update_user_directory
+
+ async def on_GET(self, request, user_id):
+
+ if not self.user_directory_active:
+ raise SynapseError(
+ code=400,
+ msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ UserID.from_string(user_id)
+
+ requester = await self.auth.get_user_by_req(request)
+ if user_id == requester.user.to_string():
+ raise SynapseError(
+ code=400,
+ msg="You cannot request a list of shared rooms with yourself",
+ errcode=Codes.FORBIDDEN,
+ )
+ rooms = await self.store.get_shared_rooms_for_users(
+ requester.user.to_string(), user_id
+ )
+
+ return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+ UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..6779df952f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -16,8 +16,6 @@
import itertools
import logging
-from canonicaljson import json
-
from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
+from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@@ -75,9 +74,10 @@ class SyncRestServlet(RestServlet):
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
def __init__(self, hs):
- super(SyncRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@@ -125,7 +125,7 @@ class SyncRestServlet(RestServlet):
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
- filter_object = json.loads(filter_id)
+ filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
@@ -152,10 +152,9 @@ class SyncRestServlet(RestServlet):
device_id=device_id,
)
+ since_token = None
if since is not None:
- since_token = StreamToken.from_string(since)
- else:
- since_token = None
+ since_token = await StreamToken.from_string(self.store, since)
# send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(user.to_string())
@@ -237,7 +236,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "next_batch": sync_result.next_batch.to_string(),
+ "next_batch": await sync_result.next_batch.to_string(self.store),
}
@staticmethod
@@ -414,7 +413,7 @@ class SyncRestServlet(RestServlet):
result = {
"timeline": {
"events": serialized_timeline,
- "prev_batch": room.timeline.prev_batch.to_string(),
+ "prev_batch": await room.timeline.prev_batch.to_string(self.store),
"limited": room.timeline.limited,
},
"state": {"events": serialized_state},
@@ -426,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index a3f12e8a77..bf3a79db44 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -31,7 +31,7 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs):
- super(TagListServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -56,7 +56,7 @@ class TagServlet(RestServlet):
)
def __init__(self, hs):
- super(TagServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 23709960ad..0c127a1b5f 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
- super(ThirdPartyProtocolsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
- super(ThirdPartyProtocolServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
- super(ThirdPartyUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
- super(ThirdPartyLocationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 83f3b6b70a..79317c74ba 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
- super(TokenRefreshRestServlet, self).__init__()
+ super().__init__()
async def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 6e8300d6a5..5d4be8adaf 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -39,7 +39,7 @@ class UserDirectorySearchRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(UserDirectorySearchRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index b1999d051b..c9b9e7f5ff 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -19,6 +19,7 @@
import logging
import re
+from synapse.api.constants import RoomCreationPreset
from synapse.http.servlet import RestServlet
logger = logging.getLogger(__name__)
@@ -28,9 +29,23 @@ class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
def __init__(self, hs):
- super(VersionsRestServlet, self).__init__()
+ super().__init__()
self.config = hs.config
+ # Calculate these once since they shouldn't change after start-up.
+ self.e2ee_forced_public = (
+ RoomCreationPreset.PUBLIC_CHAT
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+ self.e2ee_forced_private = (
+ RoomCreationPreset.PRIVATE_CHAT
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+ self.e2ee_forced_trusted_private = (
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+
def on_GET(self, request):
return (
200,
@@ -63,6 +78,12 @@ class VersionsRestServlet(RestServlet):
# Tchap does not currently assume this rule for r0.5.0
# XXX: Remove this when it does
"m.lazy_load_members": True,
+ # Implements additional endpoints as described in MSC2666
+ "uk.half-shot.msc2666": True,
+ # Whether new rooms will be set to encrypted or not (based on presets).
+ "io.element.e2ee_forced.public": self.e2ee_forced_public,
+ "io.element.e2ee_forced.private": self.e2ee_forced_private,
+ "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4386eb4e72..b3e4d5612e 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -22,8 +22,6 @@ from os import path
import jinja2
from jinja2 import TemplateNotFound
-from twisted.internet import defer
-
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
@@ -135,7 +133,7 @@ class ConsentResource(DirectServeHtmlResource):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id)
+ u = await self.store.get_user_by_id(qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
new file mode 100644
index 0000000000..0170950bf3
--- /dev/null
+++ b/synapse/rest/health.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+
+class HealthResource(Resource):
+ """A resource that does nothing except return a 200 with a body of `OK`,
+ which can be used as a health check.
+
+ Note: `SynapseRequest._should_log_request` ensures that requests to
+ `/health` do not get logged at INFO.
+ """
+
+ isLeaf = 1
+
+ def render_GET(self, request):
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b"OK"
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9b3f85b306..f843f02454 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,19 +15,19 @@
import logging
from typing import Dict, Set
-from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
+from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource):
- """HTTP resource for retreiving the TLS certificate and NACL signature
+ """HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
@@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource):
Supports individual GET APIs and a bulk query POST API.
- Requsts:
+ Requests:
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
+ # If there is a cache miss, request the missing keys, then recurse (and
+ # ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
- key_json = json.loads(key_json.decode("utf-8"))
+ key_json = json_decoder.decode(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
@@ -223,4 +225,4 @@ class RemoteKey(DirectServeJsonResource):
results = {"server_keys": signed_keys}
- respond_with_json_bytes(request, 200, encode_canonical_json(results))
+ respond_with_json(request, 200, results, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 9a847130c0..6568e61829 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,9 @@
import logging
import os
import urllib
+from typing import Awaitable
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -233,21 +235,21 @@ async def respond_with_responder(
finish_request(request)
-class Responder(object):
+class Responder:
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
- consumer (IConsumer)
+ consumer: The consumer to stream into.
Returns:
- Deferred: Resolves once the response has finished being written
+ Resolves once the response has finished being written
"""
pass
@@ -258,7 +260,7 @@ class Responder(object):
pass
-class FileInfo(object):
+class FileInfo:
"""Details about a requested/uploaded file.
Attributes:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index e25c382c9c..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -33,7 +33,7 @@ def _wrap_in_base_path(func):
return _wrapped
-class MediaFilePaths(object):
+class MediaFilePaths:
"""Describes where files are stored on disk.
Most of the functions have a `*_rel` variant which returns a file path that
@@ -80,7 +80,7 @@ class MediaFilePaths(object):
self, server_name, file_id, width, height, content_type, method
):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths(object):
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+ # Legacy path that was used to store thumbnails previously.
+ # Should be removed after some time, when most of the thumbnails are stored
+ # using the new path.
+ def remote_media_thumbnail_rel_legacy(
+ self, server_name, file_id, width, height, content_type
+ ):
+ top_level_type, sub_type = content_type.split("/")
+ file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ return os.path.join(
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
+ file_name,
+ )
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..e1192b47cd 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
-from typing import Dict, Tuple
+from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
+ Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -51,7 +53,7 @@ from .media_storage import MediaStorage
from .preview_url_resource import PreviewUrlResource
from .storage_provider import StorageProviderWrapper
from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
logger = logging.getLogger(__name__)
@@ -60,7 +62,7 @@ logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
-class MediaRepository(object):
+class MediaRepository:
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
@@ -135,20 +137,26 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
- self, media_type, upload_name, content, content_length, auth_user
- ):
+ self,
+ media_type: str,
+ upload_name: Optional[str],
+ content: IO,
+ content_length: int,
+ auth_user: str,
+ ) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type(str): The content type of the file
- upload_name(str): The name of the file
+ media_type: The content type of the file.
+ upload_name: The name of the file, if provided.
content: A file like object that is the content to store
- content_length(int): The length of the content
- auth_user(str): The user_id of the uploader
+ content_length: The length of the content
+ auth_user: The user_id of the uploader
Returns:
- Deferred[str]: The mxc url of the stored content
+ The mxc url of the stored content
"""
+
media_id = random_string(24)
file_info = FileInfo(server_name=None, file_id=media_id)
@@ -170,19 +178,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
- async def get_local_media(self, request, media_id, name):
+ async def get_local_media(
+ self, request: Request, media_id: str, name: Optional[str]
+ ) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
- request(twisted.web.http.Request)
- media_id (str): The media ID of the content. (This is the same as
+ request: The incoming request.
+ media_id: The media ID of the content. (This is the same as
the file_id for local content.)
- name (str|None): Optional name that, if specified, will be used as
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@@ -203,20 +212,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
- async def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(
+ self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ ) -> None:
"""Respond to requests for remote media.
Args:
- request(twisted.web.http.Request)
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
- name (str|None): Optional name that, if specified, will be used as
+ request: The incoming request.
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@@ -245,17 +254,16 @@ class MediaRepository(object):
else:
respond_404(request)
- async def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
Returns:
- Deferred[dict]: The media_info of the file
+ The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@@ -278,7 +286,9 @@ class MediaRepository(object):
return media_info
- async def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(
+ self, server_name: str, media_id: str
+ ) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -288,7 +298,7 @@ class MediaRepository(object):
remote server).
Returns:
- Deferred[(Responder, media_info)]
+ A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -319,19 +329,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(
+ self, server_name: str, media_id: str, file_id: str
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
- server_name (str): Originating server
- media_id (str): The media ID of the content (as defined by the
+ server_name: Originating server
+ media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
- file_id (str): Local file ID
+ file_id: Local file ID
Returns:
- Deferred[MediaInfo]
+ The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@@ -449,13 +461,30 @@ class MediaRepository(object):
return t_byte_source
async def generate_local_exact_thumbnail(
- self, media_id, t_width, t_height, t_method, t_type, url_cache
- ):
+ self,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ url_cache: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+ media_id,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -495,14 +524,36 @@ class MediaRepository(object):
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def generate_remote_exact_thumbnail(
- self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+ media_id,
+ server_name,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -548,32 +599,52 @@ class MediaRepository(object):
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def _generate_thumbnails(
- self, server_name, media_id, file_id, media_type, url_cache=False
- ):
+ self,
+ server_name: Optional[str],
+ media_id: str,
+ file_id: str,
+ media_type: str,
+ url_cache: bool = False,
+ ) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
- server_name (str|None): The server name if remote media, else None if local
- media_id (str): The media ID of the content. (This is the same as
+ server_name: The server name if remote media, else None if local
+ media_id: The media ID of the content. (This is the same as
the file_id for local content)
- file_id (str): Local file ID
- media_type (str): The content type of the file
- url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ file_id: Local file ID
+ media_type: The content type of the file
+ url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
- Deferred[dict]: Dict with "width" and "height" keys of original image
+ Dict with "width" and "height" keys of original image or None if the
+ media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
- return
+ return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate thumbnails for remote media %s from %s of type %s: %s",
+ media_id,
+ server_name,
+ media_type,
+ e,
+ )
+ return None
+
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -584,7 +655,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 66bc1c3360..a9586fb0b7 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,13 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
-import inspect
import logging
import os
import shutil
-from typing import Optional
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender
@@ -26,28 +24,39 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+ from .storage_provider import StorageProviderWrapper
logger = logging.getLogger(__name__)
-class MediaStorage(object):
+class MediaStorage:
"""Responsible for storing/fetching files from local sources.
Args:
- hs (synapse.server.Homeserver)
- local_media_directory (str): Base path where we store media on disk
- filepaths (MediaFilePaths)
- storage_providers ([StorageProvider]): List of StorageProvider that are
- used to fetch and store files.
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+ storage_providers: List of StorageProvider that are used to fetch and store files.
"""
- def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProviderWrapper"],
+ ):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
- async def store_file(self, source, file_info: FileInfo) -> str:
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
@@ -69,7 +78,7 @@ class MediaStorage(object):
return fname
@contextlib.contextmanager
- def store_into_file(self, file_info):
+ def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -85,7 +94,7 @@ class MediaStorage(object):
error.
Args:
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Example:
@@ -105,11 +114,7 @@ class MediaStorage(object):
async def finish():
for provider in self.storage_providers:
- # store_file is supposed to return an Awaitable, but guard
- # against improper implementations.
- result = provider.store_file(path, file_info)
- if inspect.isawaitable(result):
- await result
+ await provider.store_file(path, file_info)
finished_called[0] = True
@@ -136,21 +141,34 @@ class MediaStorage(object):
Returns:
Returns a Responder if the file was found, otherwise None.
"""
+ paths = [self._file_info_to_path(file_info)]
- path = self._file_info_to_path(file_info)
- local_path = os.path.join(self.local_media_directory, path)
- if os.path.exists(local_path):
- return FileResponder(open(local_path, "rb"))
+ # fallback for remote thumbnails with no method in the filename
+ if file_info.thumbnail and file_info.server_name:
+ paths.append(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ )
+
+ for path in paths:
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ logger.debug("responding with local file %s", local_path)
+ return FileResponder(open(local_path, "rb"))
+ logger.debug("local file %s did not exist", local_path)
for provider in self.storage_providers:
- res = provider.fetch(path, file_info)
- # Fetch is supposed to return an Awaitable, but guard against
- # improper implementations.
- if inspect.isawaitable(res):
- res = await res
- if res:
- logger.debug("Streaming %s from %s", path, provider)
- return res
+ for path in paths:
+ res = await provider.fetch(path, file_info) # type: Any
+ if res:
+ logger.debug("Streaming %s from %s", path, provider)
+ return res
+ logger.debug("%s not found on %s", path, provider)
return None
@@ -169,16 +187,26 @@ class MediaStorage(object):
if os.path.exists(local_path):
return local_path
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return legacy_local_path
+
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = provider.fetch(path, file_info)
- # Fetch is supposed to return an Awaitable, but guard against
- # improper implementations.
- if inspect.isawaitable(res):
- res = await res
+ res = await provider.fetch(path, file_info) # type: Any
if res:
with res:
consumer = BackgroundFileConsumer(
@@ -190,17 +218,11 @@ class MediaStorage(object):
raise Exception("file could not be found")
- def _file_info_to_path(self, file_info):
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
-
- Args:
- file_info (FileInfo)
-
- Returns:
- str
"""
if file_info.url_cache:
if file_info.thumbnail:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 13d1a6d2ed..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,9 +27,7 @@ from typing import Dict, Optional
from urllib import parse as urlparse
import attr
-from canonicaljson import json
-from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
@@ -43,6 +41,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.stringutils import random_string
@@ -103,7 +102,7 @@ for endpoint, globs in _oembed_globs.items():
_oembed_patterns[re.compile(pattern)] = endpoint
-@attr.s
+@attr.s(slots=True)
class OEmbedResult:
# Either HTML content or URL must be provided.
html = attr.ib(type=Optional[str])
@@ -228,19 +227,19 @@ class PreviewUrlResource(DirectServeJsonResource):
else:
logger.info("Returning cached response")
- og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
+ og = await make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)
- async def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
- url (str):
- user (str):
- ts (int):
+ url: The URL to preview.
+ user: The user requesting the preview.
+ ts: The timestamp requested for the preview.
Returns:
- Deferred[bytes]: json-encoded og data
+ json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
@@ -355,7 +354,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Calculated OG for %s as %s", url, og)
- jsonog = json.dumps(og)
+ jsonog = json_encoder.encode(og)
# store OG in history-aware DB cache
await self.store.store_url_cache(
@@ -451,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url, user):
+ async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -461,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url
+ url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -521,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
- etag = headers["ETag"][0] if "ETag" in headers else None
+ etag = (
+ headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ )
else:
- html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ # we can only get here if we did an oembed request and have an oembed_result.html
+ assert oembed_result.html is not None
+ assert oembed_url is not None
+
+ html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
@@ -586,7 +591,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry")
- if not (await self.store.db.updates.has_completed_background_updates()):
+ if not (await self.store.db_pool.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..18c9ed48d6 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,65 +13,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import os
import shutil
-
-from twisted.internet import defer
+from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
-class StorageProvider(object):
+class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
-
- Returns:
- Deferred
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
"""
- pass
- def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
Returns:
- Deferred(Responder): Returns a Responder if the provider has the file,
- otherwise returns None.
+ Returns a Responder if the provider has the file, otherwise returns None.
"""
- pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
- backend (StorageProvider)
- store_local (bool): Whether to store new local files or not.
- store_synchronous (bool): Whether to wait for file to be successfully
+ backend: The storage provider to wrap.
+ store_local: Whether to store new local files or not.
+ store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
- store_remote (bool): Whether remote media should be uploaded
+ store_remote: Whether remote media should be uploaded
"""
- def __init__(self, backend, store_local, store_synchronous, store_remote):
+ def __init__(
+ self,
+ backend: StorageProvider,
+ store_local: bool,
+ store_synchronous: bool,
+ store_remote: bool,
+ ):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@@ -80,28 +81,38 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
- return defer.succeed(None)
+ return None
if file_info.server_name and not self.store_remote:
- return defer.succeed(None)
+ return None
if self.store_synchronous:
- return self.backend.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
else:
# TODO: Handle errors.
- def store():
+ async def store():
try:
- return self.backend.store_file(path, file_info)
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
except Exception:
logger.exception("Error storing file")
run_in_background(store)
- return defer.succeed(None)
+ return None
- def fetch(self, path, file_info):
- return self.backend.fetch(path, file_info)
+ async def fetch(self, path, file_info):
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.fetch(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +131,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +141,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return defer_to_thread(
+ return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- def fetch(self, path, file_info):
+ async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _select_or_generate_remote_thumbnail(
self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 7126997134..32a8e4f960 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
import logging
from io import BytesIO
-from PIL import Image as Image
+from PIL import Image
logger = logging.getLogger(__name__)
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
}
-class Thumbnailer(object):
+class ThumbnailError(Exception):
+ """An error occurred generating a thumbnail."""
+
+
+class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
- self.image = Image.open(input_path)
+ try:
+ self.image = Image.open(input_path)
+ except OSError as e:
+ # If an error occurs opening the image, a thumbnail won't be able to
+ # be generated.
+ raise ThumbnailError from e
+
self.width, self.height = self.image.size
self.transpose_method = None
try:
@@ -73,7 +83,7 @@ class Thumbnailer(object):
Args:
max_width: The largest possible width.
- max_height: The larget possible height.
+ max_height: The largest possible height.
"""
if max_width * self.height < max_height * self.width:
@@ -107,7 +117,7 @@ class Thumbnailer(object):
Args:
max_width: The largest possible width.
- max_height: The larget possible height.
+ max_height: The largest possible height.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 3ebf7a68e6..d76f7389e1 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -63,6 +63,10 @@ class UploadResource(DirectServeJsonResource):
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
)
+ # If the name is falsey (e.g. an empty byte string) ensure it is None.
+ else:
+ upload_name = None
+
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index c10188a5d7..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,10 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.python import failure
-from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeHtmlResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource
class SAML2ResponseResource(DirectServeHtmlResource):
@@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource):
def __init__(self, hs):
super().__init__()
self._saml_handler = hs.get_saml_handler()
- self._error_html_template = hs.config.saml2.saml2_error_html_template
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
- f = failure.Failure(
- SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+ self._saml_handler._render_error(
+ request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
- return_html_error(f, request, self._error_html_template)
async def _async_render_POST(self, request):
- try:
- await self._saml_handler.handle_saml_response(request)
- except Exception:
- f = failure.Failure()
- return_html_error(f, request, self._error_html_template)
+ await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/client/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
new file mode 100644
index 0000000000..9e4fbc0cbd
--- /dev/null
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.errors import ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.server import DirectServeHtmlResource
+from synapse.http.servlet import parse_string
+from synapse.util.stringutils import assert_valid_client_secret
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
+ """Handles 3PID validation token submission
+
+ This resource gets mounted under /_synapse/client/password_reset/email/submit_token
+ """
+
+ isLeaf = 1
+
+ def __init__(self, hs: "HomeServer"):
+ """
+ Args:
+ hs: server
+ """
+ super().__init__()
+
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ self._local_threepid_handling_disabled_due_to_email_config = (
+ hs.config.local_threepid_handling_disabled_due_to_email_config
+ )
+ self._confirmation_email_template = (
+ hs.config.email_password_reset_template_confirmation_html
+ )
+ self._email_password_reset_template_success_html = (
+ hs.config.email_password_reset_template_success_html_content
+ )
+ self._failure_email_template = (
+ hs.config.email_password_reset_template_failure_html
+ )
+
+ # This resource should not be mounted if threepid behaviour is not LOCAL
+ assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+
+ async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
+
+ # Show a confirmation page, just in case someone accidentally clicked this link when
+ # they didn't mean to
+ template_vars = {
+ "sid": sid,
+ "token": token,
+ "client_secret": client_secret,
+ }
+ return (
+ 200,
+ self._confirmation_email_template.render(**template_vars).encode("utf-8"),
+ )
+
+ async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]:
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warning(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ next_link_bytes = next_link.encode("utf-8")
+ request.setHeader("Location", next_link_bytes)
+ return (
+ 302,
+ (
+ b'You are being redirected to <a src="%s">%s</a>.'
+ % (next_link_bytes, next_link_bytes)
+ ),
+ )
+
+ # Otherwise show the success template
+ html_bytes = self._email_password_reset_template_success_html.encode(
+ "utf-8"
+ )
+ status_code = 200
+ except ThreepidValidationError as e:
+ status_code = e.code
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html_bytes = self._failure_email_template.render(**template_vars).encode(
+ "utf-8"
+ )
+
+ return status_code, html_bytes
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 20177b44e7..f591cc6c5c 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,17 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
from twisted.web.resource import Resource
from synapse.http.server import set_cors_headers
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
-class WellKnownBuilder(object):
+class WellKnownBuilder:
"""Utility to construct the well-known response
Args:
@@ -67,4 +67,4 @@ class WellKnownResource(Resource):
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
- return json.dumps(r).encode("utf-8")
+ return json_encoder.encode(r).encode("utf-8")
|