diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3adc576124..e04af705eb 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import JsonResource
+from typing import TYPE_CHECKING
+
+from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
@@ -57,6 +59,9 @@ from synapse.rest.client import (
voip,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
- def register_servlets(client_resource, hs):
+ def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)
# Deprecated in r0
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 5715190a78..a6fa03c90f 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore()
async def on_GET(
- self, request: SynapseRequest, user_id, device_id: str
+ self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ad83d4b54c..8f781f745f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -125,7 +125,7 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- search_term = parse_string(request, "search_term")
+ search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
raise SynapseError(
400,
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index f5a38c2670..19f84f33f2 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, json_resource: HttpServer):
+ def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index c1a1ba645e..681e491826 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {}
self.hs = hs
- def _clear_old_nonces(self):
+ def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index ed96978448..bf14ec384e 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,6 +14,7 @@
import logging
import re
+from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request
@@ -42,25 +43,25 @@ logger = logging.getLogger(__name__)
class RoomBatchSendEventRestServlet(RestServlet):
"""
- API endpoint which can insert a chunk of events historically back in time
+ API endpoint which can insert a batch of events historically back in time
next to the given `prev_event`.
- `chunk_id` comes from `next_chunk_id `in the response of the batch send
- endpoint and is derived from the "insertion" events added to each chunk.
+ `batch_id` comes from `next_batch_id` in the response of the batch send
+ endpoint and is derived from the "insertion" events added to each batch.
It's not required for the first batch send.
`state_events_at_start` is used to define the historical state events
needed to auth the events like join events. These events will float
outside of the normal DAG as outlier's and won't be visible in the chat
- history which also allows us to insert multiple chunks without having a bunch
- of `@mxid joined the room` noise between each chunk.
+ history which also allows us to insert multiple batches without having a bunch
+ of `@mxid joined the room` noise between each batch.
- `events` is chronological chunk/list of events you want to insert.
- There is a reverse-chronological constraint on chunks so once you insert
+ `events` is a chronological list of events you want to insert.
+ There is a reverse-chronological constraint on batches so once you insert
some messages, you can only insert older ones after that.
- tldr; Insert chunks from your most recent history -> oldest history.
+ tldr; Insert batches from your most recent history -> oldest history.
- POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
+ POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event_id=<eventID>&batch_id=<batchID>
{
"events": [ ... ],
"state_events_at_start": [ ... ]
@@ -128,7 +129,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self, sender: str, room_id: str, origin_server_ts: int
) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
- and a random chunk ID.
+ and a random batch ID.
Args:
sender: The event author MXID
@@ -139,13 +140,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
The new event dictionary to insert.
"""
- next_chunk_id = random_string(8)
+ next_batch_id = random_string(8)
insertion_event = {
"type": EventTypes.MSC2716_INSERTION,
"sender": sender,
"room_id": room_id,
"content": {
- EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+ EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
EventContentFields.MSC2716_HISTORICAL: True,
},
"origin_server_ts": origin_server_ts,
@@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if not requester.app_service:
raise AuthError(
- 403,
+ HTTPStatus.FORBIDDEN,
"Only application services can use the /batchsend endpoint",
)
@@ -187,24 +188,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"])
assert request.args is not None
- prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
- chunk_id_from_query = parse_string(request, "chunk_id")
+ prev_event_ids_from_query = parse_strings_from_args(
+ request.args, "prev_event_id"
+ )
+ batch_id_from_query = parse_string(request, "batch_id")
- if prev_events_from_query is None:
+ if prev_event_ids_from_query is None:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"prev_event query parameter is required when inserting historical messages back in time",
errcode=Codes.MISSING_PARAM,
)
- # For the event we are inserting next to (`prev_events_from_query`),
+ # For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
(
most_recent_prev_event_id,
_,
- ) = await self.store.get_max_depth_of(prev_events_from_query)
+ ) = await self.store.get_max_depth_of(prev_event_ids_from_query)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id
@@ -213,7 +216,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
- state_events_at_start = []
+ state_event_ids_at_start = []
for state_event in body["state_events_at_start"]:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -279,27 +282,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
event_id = event.event_id
- state_events_at_start.append(event_id)
+ state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
events_to_create = body["events"]
inherited_depth = await self._inherit_depth_from_prev_ids(
- prev_events_from_query
+ prev_event_ids_from_query
)
- # Figure out which chunk to connect to. If they passed in
- # chunk_id_from_query let's use it. The chunk ID passed in comes
- # from the chunk_id in the "insertion" event from the previous chunk.
- last_event_in_chunk = events_to_create[-1]
- chunk_id_to_connect_to = chunk_id_from_query
+ # Figure out which batch to connect to. If they passed in
+ # batch_id_from_query let's use it. The batch ID passed in comes
+ # from the batch_id in the "insertion" event from the previous batch.
+ last_event_in_batch = events_to_create[-1]
+ batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None
- if chunk_id_from_query:
+ if batch_id_from_query:
# All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of
- # the chunk later.
+ # the batch later.
prev_event_ids = [fake_prev_event_id]
- # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+
+ # Verify the batch_id_from_query corresponds to an actual insertion event
+ # so the batch can be connected to it.
+ corresponding_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
+ )
+ if corresponding_insertion_event_id is None:
+ raise SynapseError(
+ 400,
+ "No insertion event corresponds to the given ?batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
pass
# Otherwise, create an insertion event to act as a starting point.
#
@@ -309,12 +323,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
# an insertion event), in which case we just create a new insertion event
# that can then get pointed to by a "marker" event later.
else:
- prev_event_ids = prev_events_from_query
+ prev_event_ids = prev_event_ids_from_query
base_insertion_event_dict = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
- origin_server_ts=last_event_in_chunk["origin_server_ts"],
+ origin_server_ts=last_event_in_batch["origin_server_ts"],
)
base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@@ -333,38 +347,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
- chunk_id_to_connect_to = base_insertion_event["content"][
- EventContentFields.MSC2716_NEXT_CHUNK_ID
+ batch_id_to_connect_to = base_insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
]
- # Connect this current chunk to the insertion event from the previous chunk
- chunk_event = {
- "type": EventTypes.MSC2716_CHUNK,
+ # Connect this current batch to the insertion event from the previous batch
+ batch_event = {
+ "type": EventTypes.MSC2716_BATCH,
"sender": requester.user.to_string(),
"room_id": room_id,
"content": {
- EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to,
+ EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
EventContentFields.MSC2716_HISTORICAL: True,
},
- # Since the chunk event is put at the end of the chunk,
+ # Since the batch event is put at the end of the batch,
# where the newest-in-time event is, copy the origin_server_ts from
# the last event we're inserting
- "origin_server_ts": last_event_in_chunk["origin_server_ts"],
+ "origin_server_ts": last_event_in_batch["origin_server_ts"],
}
- # Add the chunk event to the end of the chunk (newest-in-time)
- events_to_create.append(chunk_event)
+ # Add the batch event to the end of the batch (newest-in-time)
+ events_to_create.append(batch_event)
- # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
- # event in the chunk) so the next chunk can be connected to this one.
+ # Add an "insertion" event to the start of each batch (next to the oldest-in-time
+ # event in the batch) so the next batch can be connected to this one.
insertion_event = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
- # Since the insertion event is put at the start of the chunk,
+ # Since the insertion event is put at the start of the batch,
# where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"],
)
- # Prepend the insertion event to the start of the chunk (oldest-in-time)
+ # Prepend the insertion event to the start of the batch (oldest-in-time)
events_to_create = [insertion_event] + events_to_create
event_ids = []
@@ -424,20 +438,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
context=context,
)
- # Add the base_insertion_event to the bottom of the list we return
- if base_insertion_event is not None:
- event_ids.append(base_insertion_event.event_id)
+ insertion_event_id = event_ids[0]
+ batch_event_id = event_ids[-1]
+ historical_event_ids = event_ids[1:-1]
- return 200, {
- "state_events": state_events_at_start,
- "events": event_ids,
- "next_chunk_id": insertion_event["content"][
- EventContentFields.MSC2716_NEXT_CHUNK_ID
+ response_dict = {
+ "state_event_ids": state_event_ids_at_start,
+ "event_ids": historical_event_ids,
+ "next_batch_id": insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
],
+ "insertion_event_id": insertion_event_id,
+ "batch_event_id": batch_event_id,
}
+ if base_insertion_event is not None:
+ response_dict["base_insertion_event_id"] = base_insertion_event.event_id
+
+ return HTTPStatus.OK, response_dict
def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
- return 501, "Not implemented"
+ return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
def on_PUT(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 11f7320832..06e0fbde22 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -17,17 +17,22 @@ import logging
from hashlib import sha256
from http import HTTPStatus
from os import path
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
import jinja2
from jinja2 import TemplateNotFound
+from twisted.web.server import Request
+
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"
@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): homeserver
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- async def _async_render_GET(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- async def _async_render_POST(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
@@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("success.html not found")
- def _render_template(self, request, template_name, **template_args):
+ def _render_template(
+ self, request: Request, template_name: str, **template_args: Any
+ ) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
@@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)
- def _check_hash(self, userid, userhmac):
+ def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
- userid (unicode):
- userhmac (bytes):
+ userid:
+ userhmac:
Raises:
SynapseError if the hash doesn't match
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
index 4487b54abf..78df7af2cf 100644
--- a/synapse/rest/health.py
+++ b/synapse/rest/health.py
@@ -13,6 +13,7 @@
# limitations under the License.
from twisted.web.resource import Resource
+from twisted.web.server import Request
class HealthResource(Resource):
@@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index c6c63073ea..7f8c1de1ff 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class KeyApiV2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 25f6eb842f..ebe243bcfd 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -58,18 +63,18 @@ class LocalKey(Resource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
- def update_response_body(self, time_now_msec):
+ def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
- def response_json_object(self):
+ def response_json_object(self) -> JsonDict:
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
@@ -94,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 744360e5fd..d8fd7938a4 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
+from twisted.web.server import Request
+
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
@@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
+ assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
@@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(
+ self,
+ request: Request,
+ query: JsonDict,
+ query_remote_on_cache_miss: bool = False,
+ ) -> None:
logger.info("Handling query for keys %r", query)
store_queries = []
@@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), results in cached.items():
- results = [(result["ts_added_ms"], result) for result in results]
+ for (server_name, key_id, _), key_results in cached.items():
+ results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
@@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
- results = {"server_keys": signed_keys}
+ response = {"server_keys": signed_keys}
- respond_with_json(request, 200, results, canonical_json=True)
+ respond_with_json(request, 200, response, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 90364ebcf7..7c881f2bdb 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,7 +16,10 @@
import logging
import os
import urllib
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
@@ -120,7 +123,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any.
"""
- def _quote(x):
+ def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types.
@@ -280,51 +283,74 @@ class Responder:
"""
pass
- def __enter__(self):
+ def __enter__(self) -> None:
pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
pass
-class FileInfo:
- """Details about a requested/uploaded file.
-
- Attributes:
- server_name (str): The server name where the media originated from,
- or None if local.
- file_id (str): The local ID of the file. For local files this is the
- same as the media_id
- url_cache (bool): If the file is for the url preview cache
- thumbnail (bool): Whether the file is a thumbnail or not.
- thumbnail_width (int)
- thumbnail_height (int)
- thumbnail_method (str)
- thumbnail_type (str): Content type of thumbnail, e.g. image/png
- thumbnail_length (int): The size of the media file, in bytes.
- """
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+ """Details about a generated thumbnail."""
- def __init__(
- self,
- server_name,
- file_id,
- url_cache=False,
- thumbnail=False,
- thumbnail_width=None,
- thumbnail_height=None,
- thumbnail_method=None,
- thumbnail_type=None,
- thumbnail_length=None,
- ):
- self.server_name = server_name
- self.file_id = file_id
- self.url_cache = url_cache
- self.thumbnail = thumbnail
- self.thumbnail_width = thumbnail_width
- self.thumbnail_height = thumbnail_height
- self.thumbnail_method = thumbnail_method
- self.thumbnail_type = thumbnail_type
- self.thumbnail_length = thumbnail_length
+ width: int
+ height: int
+ method: str
+ # Content type of thumbnail, e.g. image/png
+ type: str
+ # The size of the media file, in bytes.
+ length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+ """Details about a requested/uploaded file."""
+
+ # The server name where the media originated from, or None if local.
+ server_name: Optional[str]
+ # The local ID of the file. For local files this is the same as the media_id
+ file_id: str
+ # If the file is for the url preview cache
+ url_cache: bool = False
+ # Whether the file is a thumbnail or not.
+ thumbnail: Optional[ThumbnailInfo] = None
+
+ # The below properties exist to maintain compatibility with third-party modules.
+ @property
+ def thumbnail_width(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.width
+
+ @property
+ def thumbnail_height(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.height
+
+ @property
+ def thumbnail_method(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.method
+
+ @property
+ def thumbnail_type(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.type
+
+ @property
+ def thumbnail_length(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 09531ebf54..39bbe4e874 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,7 @@
import functools
import os
import re
-from typing import Callable, List
+from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
"""
@functools.wraps(func)
- def _wrapped(self, *args, **kwargs):
+ def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
@@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
- ):
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0f5ce41ff8..50e4c9e29f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -32,6 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
+from synapse.config.repository import ThumbnailRequirement
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@@ -42,6 +44,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
+ ThumbnailInfo,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -113,7 +116,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
- def _start_update_recently_accessed(self):
+ def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
)
@@ -210,7 +213,7 @@ class MediaRepository:
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
- file_info = FileInfo(None, media_id, url_cache=url_cache)
+ file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
@@ -468,7 +471,9 @@ class MediaRepository:
return media_info
- def _get_thumbnail_requirements(self, media_type):
+ def _get_thumbnail_requirements(
+ self, media_type: str
+ ) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";")
if scpos > 0:
media_type = media_type[:scpos]
@@ -514,7 +519,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
- url_cache: Optional[str],
+ url_cache: bool,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
@@ -548,11 +553,12 @@ class MediaRepository:
server_name=None,
file_id=media_id,
url_cache=url_cache,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -585,7 +591,7 @@ class MediaRepository:
t_type: str,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
- FileInfo(server_name, file_id, url_cache=False)
+ FileInfo(server_name, file_id)
)
try:
@@ -616,11 +622,12 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -742,12 +749,13 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
url_cache=url_cache,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 56cdc1b4ed..01fada8fb5 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -15,7 +15,20 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+from types import TracebackType
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Generator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+)
import attr
@@ -83,12 +96,14 @@ class MediaStorage:
return fname
- async def write_to_file(self, source: IO, output: IO):
+ async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
- def store_into_file(self, file_info: FileInfo):
+ def store_into_file(
+ self, file_info: FileInfo
+ ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -125,7 +140,7 @@ class MediaStorage:
try:
with open(fname, "wb") as f:
- async def finish():
+ async def finish() -> None:
# Ensure that all writes have been flushed and close the
# file.
f.flush()
@@ -176,9 +191,9 @@ class MediaStorage:
self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
)
@@ -220,9 +235,9 @@ class MediaStorage:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
@@ -255,10 +270,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
@@ -267,10 +282,10 @@ class MediaStorage:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id
@@ -279,10 +294,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.local_media_filepath_rel(file_info.file_id)
@@ -315,7 +330,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer)
)
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
self.open_file.close()
@@ -339,7 +359,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock)
path = attr.ib(type=str)
- async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2e6706dbfa..8b74e72655 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import urllib.parse
from typing import TYPE_CHECKING, Optional
import attr
from synapse.http.client import SimpleHttpClient
+from synapse.types import JsonDict
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -24,18 +27,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class OEmbedResult:
- # Either HTML content or URL must be provided.
- html: Optional[str]
- url: Optional[str]
- title: Optional[str]
- # Number of seconds to cache the content.
- cache_age: int
-
-
-class OEmbedError(Exception):
- """An error occurred processing the oEmbed object."""
+ # The Open Graph result (converted from the oEmbed result).
+ open_graph_result: JsonDict
+ # Number of seconds to cache the content, according to the oEmbed response.
+ #
+ # This will be None if no cache-age is provided in the oEmbed response (or
+ # if the oEmbed response cannot be turned into an Open Graph response).
+ cache_age: Optional[int]
class OEmbedProvider:
@@ -81,75 +81,106 @@ class OEmbedProvider:
"""
for url_pattern, endpoint in self._oembed_patterns.items():
if url_pattern.fullmatch(url):
- return endpoint
+ # TODO Specify max height / width.
+
+ # Note that only the JSON format is supported, some endpoints want
+ # this in the URL, others want it as an argument.
+ endpoint = endpoint.replace("{format}", "json")
+
+ args = {"url": url, "format": "json"}
+ query_str = urllib.parse.urlencode(args, True)
+ return f"{endpoint}?{query_str}"
# No match.
return None
- async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"""
- Request content from an oEmbed endpoint.
+ Parse the oEmbed response into an Open Graph response.
Args:
- endpoint: The oEmbed API endpoint.
- url: The URL to pass to the API.
+ url: The URL which is being previewed (not the one which was
+ requested).
+ raw_body: The oEmbed response as JSON encoded as bytes.
Returns:
- An object representing the metadata returned.
-
- Raises:
- OEmbedError if fetching or parsing of the oEmbed information fails.
+            An OEmbedResult, wrapping the Open Graph response and cache age.
"""
- try:
- logger.debug("Trying to get oEmbed content for url '%s'", url)
- # Note that only the JSON format is supported, some endpoints want
- # this in the URL, others want it as an argument.
- endpoint = endpoint.replace("{format}", "json")
-
- result = await self._client.get_json(
- endpoint,
- # TODO Specify max height / width.
- args={"url": url, "format": "json"},
- )
+ try:
+ # oEmbed responses *must* be UTF-8 according to the spec.
+ oembed = json_decoder.decode(raw_body.decode("utf-8"))
# Ensure there's a version of 1.0.
- if result.get("version") != "1.0":
- raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
- oembed_type = result.get("type")
+ oembed_version = oembed["version"]
+ if oembed_version != "1.0":
+ raise RuntimeError(f"Invalid version: {oembed_version}")
# Ensure the cache age is None or an int.
- cache_age = result.get("cache_age")
+ cache_age = oembed.get("cache_age")
if cache_age:
cache_age = int(cache_age)
- oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+ # The results.
+ open_graph_response = {"og:title": oembed.get("title")}
- # HTML content.
+ # If a thumbnail exists, use it. Note that dimensions will be calculated later.
+ if "thumbnail_url" in oembed:
+ open_graph_response["og:image"] = oembed["thumbnail_url"]
+
+ # Process each type separately.
+ oembed_type = oembed["type"]
if oembed_type == "rich":
- oembed_result.html = result.get("html")
- return oembed_result
+ calc_description_and_urls(open_graph_response, oembed["html"])
- if oembed_type == "photo":
- oembed_result.url = result.get("url")
- return oembed_result
+ elif oembed_type == "photo":
+ # If this is a photo, use the full image, not the thumbnail.
+ open_graph_response["og:image"] = oembed["url"]
- # TODO Handle link and video types.
+ else:
+ raise RuntimeError(f"Unknown oEmbed type: {oembed_type}")
- if "thumbnail_url" in result:
- oembed_result.url = result.get("thumbnail_url")
- return oembed_result
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+        logger.warning(f"Error parsing oEmbed metadata from {url}: {e!r}")
+ open_graph_response = {}
+ cache_age = None
- raise OEmbedError("Incompatible oEmbed information.")
+ return OEmbedResult(open_graph_response, cache_age)
- except OEmbedError as e:
- # Trap OEmbedErrors first so we can directly re-raise them.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- raise
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
- raise OEmbedError() from e
+def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None:
+ """
+ Calculate description for an HTML document.
+
+ This uses lxml to convert the HTML document into plaintext. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ open_graph_response: The current Open Graph summary. This is updated with additional fields.
+        html_body: The HTML document, as a string.
+
+ Returns:
+        None; open_graph_response is updated in place.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not html_body:
+ return
+
+ from lxml import etree
+
+ # Create an HTML parser. If this fails, log and return no metadata.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ tree = etree.fromstring(html_body, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return
+
+ from synapse.rest.media.v1.preview_url_resource import _calc_description
+
+ description = _calc_description(tree)
+ if description:
+ open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f108da05db..0a0b476d2b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,6 +27,7 @@ from urllib import parse as urlparse
import attr
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
@@ -43,7 +44,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
-from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -72,6 +73,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
+ONE_DAY = 24 * ONE_HOUR
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -254,10 +256,19 @@ class PreviewUrlResource(DirectServeJsonResource):
og = og.encode("utf8")
return og
- media_info = await self._download_url(url, user)
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._oembed.get_oembed_url(url)
+ if oembed_url:
+ url_to_download = oembed_url
+
+ media_info = await self._download_url(url_to_download, user)
logger.debug("got media_info of '%s'", media_info)
+ # The number of milliseconds that the response should be considered valid.
+ expiration_ms = media_info.expires
+
if _is_media(media_info.media_type):
file_id = media_info.filesystem_id
dims = await self.media_repo._generate_thumbnails(
@@ -287,34 +298,22 @@ class PreviewUrlResource(DirectServeJsonResource):
encoding = get_html_media_encoding(body, media_info.media_type)
og = decode_and_calc_og(body, media_info.uri, encoding)
- # pre-cache the image for posterity
- # FIXME: it might be cleaner to use the same flow as the main /preview_url
- # request itself and benefit from the same caching etc. But for now we
- # just rely on the caching on the master request to speed things up.
- if "og:image" in og and og["og:image"]:
- image_info = await self._download_url(
- _rebase_url(og["og:image"], media_info.uri), user
- )
+ await self._precache_image_url(user, media_info, og)
+
+ elif oembed_url and _is_json(media_info.media_type):
+ # Handle an oEmbed response.
+ with open(media_info.filename, "rb") as file:
+ body = file.read()
+
+ oembed_response = self._oembed.parse_oembed_response(media_info.uri, body)
+ og = oembed_response.open_graph_result
+
+ # Use the cache age from the oEmbed result, instead of the HTTP response.
+ if oembed_response.cache_age is not None:
+ expiration_ms = oembed_response.cache_age
+
+ await self._precache_image_url(user, media_info, og)
- if _is_media(image_info.media_type):
- # TODO: make sure we don't choke on white-on-transparent images
- file_id = image_info.filesystem_id
- dims = await self.media_repo._generate_thumbnails(
- None, file_id, file_id, image_info.media_type, url_cache=True
- )
- if dims:
- og["og:image:width"] = dims["width"]
- og["og:image:height"] = dims["height"]
- else:
- logger.warning("Couldn't get dims for %s", og["og:image"])
-
- og[
- "og:image"
- ] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
- og["og:image:type"] = image_info.media_type
- og["matrix:image:size"] = image_info.media_length
- else:
- del og["og:image"]
else:
logger.warning("Failed to find any OG data in %s", url)
og = {}
@@ -335,12 +334,15 @@ class PreviewUrlResource(DirectServeJsonResource):
jsonog = json_encoder.encode(og)
+ # Cap the amount of time to consider a response valid.
+ expiration_ms = min(expiration_ms, ONE_DAY)
+
# store OG in history-aware DB cache
await self.store.store_url_cache(
url,
media_info.response_code,
media_info.etag,
- media_info.expires + media_info.created_ts_ms,
+ media_info.created_ts_ms + expiration_ms,
jsonog,
media_info.filesystem_id,
media_info.created_ts_ms,
@@ -357,88 +359,52 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- # If this URL can be accessed via oEmbed, use that instead.
- url_to_download: Optional[str] = url
- oembed_url = self._oembed.get_oembed_url(url)
- if oembed_url:
- # The result might be a new URL to download, or it might be HTML content.
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
- if oembed_result.url:
- url_to_download = oembed_result.url
- elif oembed_result.html:
- url_to_download = None
- except OEmbedError:
- # If an error occurs, try doing a normal preview.
- pass
+ logger.debug("Trying to get preview for url '%s'", url)
+ length, headers, uri, code = await self.client.get_file(
+ url,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url, e)
- if url_to_download:
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- try:
- logger.debug("Trying to get preview for url '%s'", url_to_download)
- length, headers, uri, code = await self.client.get_file(
- url_to_download,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url_to_download, e)
-
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
-
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
- download_name = get_filename_from_headers(headers)
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- expires = ONE_HOUR
- etag = (
- headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
- )
- else:
- # we can only get here if we did an oembed request and have an oembed_result.html
- assert oembed_result.html is not None
- assert oembed_url is not None
-
- html_bytes = oembed_result.html.encode("utf-8")
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- f.write(html_bytes)
- await finish()
-
- media_type = "text/html"
- download_name = oembed_result.title
- length = len(html_bytes)
- # If a specific cache age was not given, assume 1 hour.
- expires = oembed_result.cache_age or ONE_HOUR
- uri = oembed_url
- code = 200
- etag = None
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
try:
time_now_ms = self.clock.time_msec()
@@ -473,7 +439,47 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag,
)
- def _start_expire_url_cache_data(self):
+ async def _precache_image_url(
+ self, user: str, media_info: MediaInfo, og: JsonDict
+ ) -> None:
+ """
+ Pre-cache the image (if one exists) for posterity
+
+ Args:
+ user: The user requesting the preview.
+ media_info: The media being previewed.
+ og: The Open Graph dictionary. This is modified with image information.
+ """
+ # If there's no image or it is blank, there's nothing to do.
+ if "og:image" not in og or not og["og:image"]:
+ return
+
+ # FIXME: it might be cleaner to use the same flow as the main /preview_url
+ # request itself and benefit from the same caching etc. But for now we
+ # just rely on the caching on the master request to speed things up.
+ image_info = await self._download_url(
+ _rebase_url(og["og:image"], media_info.uri), user
+ )
+
+ if _is_media(image_info.media_type):
+ # TODO: make sure we don't choke on white-on-transparent images
+ file_id = image_info.filesystem_id
+ dims = await self.media_repo._generate_thumbnails(
+ None, file_id, file_id, image_info.media_type, url_cache=True
+ )
+ if dims:
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
+ else:
+ logger.warning("Couldn't get dims for %s", og["og:image"])
+
+ og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
+ og["og:image:type"] = image_info.media_type
+ og["matrix:image:size"] = image_info.media_length
+ else:
+ del og["og:image"]
+
+ def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
)
@@ -526,7 +532,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * ONE_HOUR
+ expire_before = now - 2 * ONE_DAY
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
@@ -668,7 +674,18 @@ def decode_and_calc_og(
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
- # suck our tree into lxml and define our OG response.
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to search the HTML document for Open Graph data.
+
+ Args:
+ tree: The parsed HTML document.
+        media_uri: The URI used to download the body.
+
+ Returns:
+ The Open Graph response as a dictionary.
+ """
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
@@ -742,35 +759,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
if meta_description:
og["og:description"] = meta_description[0]
else:
- # grab any text nodes which are inside the <body/> tag...
- # unless they are within an HTML5 semantic markup tag...
- # <header/>, <nav/>, <aside/>, <footer/>
- # ...or if they are within a <script/> or <style/> tag.
- # This is a very very very coarse approximation to a plain text
- # render of the page.
-
- # We don't just use XPATH here as that is slow on some machines.
-
- from lxml import etree
-
- TAGS_TO_REMOVE = (
- "header",
- "nav",
- "aside",
- "footer",
- "script",
- "noscript",
- "style",
- etree.Comment,
- )
-
- # Split all the text nodes into paragraphs (by splitting on new
- # lines)
- text_nodes = (
- re.sub(r"\s+", "\n", el).strip()
- for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
- )
- og["og:description"] = summarize_paragraphs(text_nodes)
+ og["og:description"] = _calc_description(tree)
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
@@ -781,8 +770,48 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
return og
+def _calc_description(tree: "etree.Element") -> Optional[str]:
+ """
+ Calculate a text description based on an HTML document.
+
+ Grabs any text nodes which are inside the <body/> tag, unless they are within
+ an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+ if they are within a <script/> or <style/> tag.
+
+ This is a very very very coarse approximation to a plain text render of the page.
+
+ Args:
+ tree: The parsed HTML document.
+
+ Returns:
+ The plain text description, or None if one cannot be generated.
+ """
+ # We don't just use XPATH here as that is slow on some machines.
+
+ from lxml import etree
+
+ TAGS_TO_REMOVE = (
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment,
+ )
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r"\s+", "\n", el).strip()
+ for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+ )
+ return summarize_paragraphs(text_nodes)
+
+
def _iterate_over_text(
- tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+ tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
@@ -840,11 +869,25 @@ def _is_html(content_type: str) -> bool:
)
+def _is_json(content_type: str) -> bool:
+ return content_type.lower().startswith("application/json")
+
+
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
- # Try to get a summary of between 200 and 500 words, respecting
- # first paragraph and then word boundaries.
+ """
+ Try to get a summary respecting first paragraph and then word boundaries.
+
+ Args:
+ text_nodes: The paragraphs to summarize.
+ min_size: The minimum number of words to include.
+ max_size: The maximum number of words to include.
+
+ Returns:
+ A summary of the text nodes, or None if that was not possible.
+ """
+
# TODO: Respect sentences?
description = ""
@@ -867,7 +910,7 @@ def summarize_paragraphs(
new_desc = ""
# This splits the paragraph into words, but keeping the
- # (preceeding) whitespace intact so we can easily concat
+ # (preceding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer(r"\s*\S+", description):
word = match.group()
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 0ff6ad3c0c..6c9969e55f 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
- async def store():
+ async def store() -> None:
try:
return await maybe_awaitable(
self.backend.store_file(path, file_info)
@@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
- def __str__(self):
+ def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 12bd745cb2..22f43d8531 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -26,6 +26,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
+ ThumbnailInfo,
parse_media_id,
respond_404,
respond_with_file,
@@ -114,7 +115,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_id,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
server_name=None,
)
@@ -149,11 +150,12 @@ class ThumbnailResource(DirectServeJsonResource):
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -173,7 +175,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height,
desired_method,
desired_type,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
)
if file_path:
@@ -210,11 +212,12 @@ class ThumbnailResource(DirectServeJsonResource):
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -271,7 +274,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_info["filesystem_id"],
- url_cache=None,
+ url_cache=False,
server_name=server_name,
)
@@ -285,7 +288,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos: List[Dict[str, Any]],
media_id: str,
file_id: str,
- url_cache: Optional[str] = None,
+ url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
@@ -299,7 +302,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos:
@@ -318,13 +321,16 @@ class ThumbnailResource(DirectServeJsonResource):
respond_404(request)
return
+ # The thumbnail property must exist.
+ assert file_info.thumbnail is not None
+
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
return
@@ -351,18 +357,18 @@ class ThumbnailResource(DirectServeJsonResource):
server_name,
file_id=file_id,
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
@@ -370,8 +376,8 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
else:
logger.info("Failed to find any generated thumbnails")
@@ -385,7 +391,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
- url_cache: Optional[str],
+ url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
@@ -398,7 +404,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
@@ -495,12 +501,13 @@ class ThumbnailResource(DirectServeJsonResource):
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- thumbnail_length=thumbnail_info["thumbnail_length"],
+ thumbnail=ThumbnailInfo(
+ width=thumbnail_info["thumbnail_width"],
+ height=thumbnail_info["thumbnail_height"],
+ type=thumbnail_info["thumbnail_type"],
+ method=thumbnail_info["thumbnail_method"],
+ length=thumbnail_info["thumbnail_length"],
+ ),
)
# No matching thumbnail was found.
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a65e9e1802..df54a40649 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
- def set_limits(max_image_pixels: int):
+ def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index 67c1ed1f5f..1c1c7b3613 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request
@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: Request):
+ async def _async_render_POST(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 36ba401656..81fec39659 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -13,16 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class OIDCResource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index 7785f17e90..4f375cb74c 100644
--- a/synapse/rest/synapse/client/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_oidc_callback(request)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# the auth response can be returned via an x-www-form-urlencoded form instead
# of GET params, as per
# https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d30b478b98..28ae083497 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Generator, List, Tuple
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -27,6 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.templates import build_jinja_env
if TYPE_CHECKING:
@@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request: Request):
+ async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
localpart = parse_string(request, "username", required=True)
session_id = get_username_mapping_session_cookie_from_request(request)
@@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: SynapseRequest):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None
diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 781ccb237c..3f247e6a2c 100644
--- a/synapse/rest/synapse/client/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -13,17 +13,21 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class SAML2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index b37c7083dc..64378ed57b 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
import saml2.metadata
from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SAML2MetadataResource(Resource):
@@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string(
configfile=None, config=self.sp_config
)
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index 774ccd870f..47d2a6a229 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -15,7 +15,10 @@
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
@@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 6a66a88c53..c80a3a99aa 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,26 +13,26 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import set_cors_headers
+from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class WellKnownBuilder:
- """Utility to construct the well-known response
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._config = hs.config
- def get_well_known(self):
+ def get_well_known(self) -> Optional[JsonDict]:
# if we don't have a public_baseurl, we can't help much here.
if self._config.server.public_baseurl is None:
return None
@@ -52,11 +52,11 @@ class WellKnownResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs)
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
set_cors_headers(request)
r = self._well_known_builder.get_well_known()
if not r:
|