diff --git a/changelog.d/10658.bugfix b/changelog.d/10658.bugfix
new file mode 100644
index 0000000000..a59d402933
--- /dev/null
+++ b/changelog.d/10658.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where room avatars were not included in email notifications.
diff --git a/changelog.d/10707.misc b/changelog.d/10707.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10707.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10712.feature b/changelog.d/10712.feature
new file mode 100644
index 0000000000..d04db6f26f
--- /dev/null
+++ b/changelog.d/10712.feature
@@ -0,0 +1 @@
+Skip final GC at shutdown to improve restart performance.
diff --git a/changelog.d/10714.feature b/changelog.d/10714.feature
new file mode 100644
index 0000000000..7d18f5c133
--- /dev/null
+++ b/changelog.d/10714.feature
@@ -0,0 +1 @@
+Allow configuration of the oEmbed URLs used for URL previews.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 935841dbfa..e155b978d8 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1075,6 +1075,27 @@ url_preview_accept_language:
# - en
+# oEmbed allows for easier embedding content from a website. It can be
+# used for generating URLs previews of services which support it.
+#
+oembed:
+ # A default list of oEmbed providers is included with Synapse.
+ #
+ # Uncomment the following to disable using these default oEmbed URLs.
+ # Defaults to 'false'.
+ #
+ #disable_default_providers: true
+
+ # Additional files with oEmbed configuration (each should be in the
+ # form of providers.json).
+ #
+ # By default, this list is empty (so only the default providers.json
+ # is used).
+ #
+ #additional_providers:
+ # - oembed/my_providers.json
+
+
## Captcha ##
# See docs/CAPTCHA_SETUP.md for full details of configuring this.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 39e28aff9f..6fc14930d1 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import atexit
import gc
import logging
import os
@@ -403,6 +404,12 @@ async def start(hs: "HomeServer"):
gc.collect()
gc.freeze()
+ # Speed up shutdowns by freezing all allocated objects. This moves everything
+ # into the permanent generation and excludes them from the final GC.
+ # Unfortunately only works on Python 3.7
+ if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+ atexit.register(gc.freeze)
+
def setup_sentry(hs):
"""Enable sentry integration, if enabled in configuration
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 1f42a51857..442f1b9ac0 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -30,6 +30,7 @@ from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .modules import ModulesConfig
+from .oembed import OembedConfig
from .oidc import OIDCConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
@@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig):
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
+ OembedConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
new file mode 100644
index 0000000000..09267b5eef
--- /dev/null
+++ b/synapse/config/oembed.py
@@ -0,0 +1,180 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import re
+from typing import Any, Dict, Iterable, List, Pattern
+from urllib import parse as urlparse
+
+import attr
+import pkg_resources
+
+from synapse.types import JsonDict
+
+from ._base import Config, ConfigError
+from ._util import validate_config
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class OEmbedEndpointConfig:
+ # The API endpoint to fetch.
+ api_endpoint: str
+ # The patterns to match.
+ url_patterns: List[Pattern]
+
+
+class OembedConfig(Config):
+ """oEmbed Configuration"""
+
+ section = "oembed"
+
+ def read_config(self, config, **kwargs):
+ oembed_config: Dict[str, Any] = config.get("oembed") or {}
+
+ # A list of patterns which will be used.
+ self.oembed_patterns: List[OEmbedEndpointConfig] = list(
+ self._parse_and_validate_providers(oembed_config)
+ )
+
+ def _parse_and_validate_providers(
+ self, oembed_config: dict
+ ) -> Iterable[OEmbedEndpointConfig]:
+ """Extract and parse the oEmbed providers from the given JSON file.
+
+ Returns a generator which yields the OidcProviderConfig objects
+ """
+ # Whether to use the packaged providers.json file.
+ if not oembed_config.get("disable_default_providers") or False:
+ providers = json.load(
+ pkg_resources.resource_stream("synapse", "res/providers.json")
+ )
+ yield from self._parse_and_validate_provider(
+ providers, config_path=("oembed",)
+ )
+
+ # The JSON files which includes additional provider information.
+ for i, file in enumerate(oembed_config.get("additional_providers") or []):
+ # TODO Error checking.
+ with open(file) as f:
+ providers = json.load(f)
+
+ yield from self._parse_and_validate_provider(
+ providers,
+ config_path=(
+ "oembed",
+ "additional_providers",
+ f"<item {i}>",
+ ),
+ )
+
+ def _parse_and_validate_provider(
+ self, providers: List[JsonDict], config_path: Iterable[str]
+ ) -> Iterable[OEmbedEndpointConfig]:
+ # Ensure it is the proper form.
+ validate_config(
+ _OEMBED_PROVIDER_SCHEMA,
+ providers,
+ config_path=config_path,
+ )
+
+ # Parse it and yield each result.
+ for provider in providers:
+ # Each provider might have multiple API endpoints, each which
+ # might have multiple patterns to match.
+ for endpoint in provider["endpoints"]:
+ api_endpoint = endpoint["url"]
+ patterns = [
+ self._glob_to_pattern(glob, config_path)
+ for glob in endpoint["schemes"]
+ ]
+ yield OEmbedEndpointConfig(api_endpoint, patterns)
+
+ def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern:
+ """
+ Convert the glob into a sane regular expression to match against. The
+ rules followed will be slightly different for the domain portion vs.
+ the rest.
+
+ 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+ 2. The domain can have globs, but we limit it to characters that can
+ reasonably be a domain part.
+ TODO: This does not attempt to handle Unicode domain names.
+ TODO: The domain should not allow wildcard TLDs.
+ 3. Other parts allow a glob to be any one, or more, characters.
+ """
+ results = urlparse.urlparse(glob)
+
+ # Ensure the scheme does not have wildcards (and is a sane scheme).
+ if results.scheme not in {"http", "https"}:
+ raise ConfigError(f"Insecure oEmbed scheme: {results.scheme}", config_path)
+
+ pattern = urlparse.urlunparse(
+ [
+ results.scheme,
+ re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+ ]
+ + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+ )
+ return re.compile(pattern)
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ # oEmbed allows for easier embedding content from a website. It can be
+ # used for generating URLs previews of services which support it.
+ #
+ oembed:
+ # A default list of oEmbed providers is included with Synapse.
+ #
+ # Uncomment the following to disable using these default oEmbed URLs.
+ # Defaults to 'false'.
+ #
+ #disable_default_providers: true
+
+ # Additional files with oEmbed configuration (each should be in the
+ # form of providers.json).
+ #
+ # By default, this list is empty (so only the default providers.json
+ # is used).
+ #
+ #additional_providers:
+ # - oembed/my_providers.json
+ """
+
+
+_OEMBED_PROVIDER_SCHEMA = {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "provider_name": {"type": "string"},
+ "provider_url": {"type": "string"},
+ "endpoints": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "schemes": {
+ "type": "array",
+ "items": {"type": "string"},
+ },
+ "url": {"type": "string"},
+ "formats": {"type": "array", "items": {"type": "string"}},
+ "discovery": {"type": "boolean"},
+ },
+ "required": ["schemes", "url"],
+ },
+ },
+ },
+ "required": ["provider_name", "provider_url", "endpoints"],
+ },
+}
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 941fb238b7..b0834720ad 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -258,7 +258,7 @@ class Mailer:
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
- rooms = []
+ rooms: List[Dict[str, Any]] = []
for r in rooms_in_order:
roomvars = await self._get_room_vars(
@@ -362,6 +362,7 @@ class Mailer:
"notifs": [],
"invite": is_invite,
"link": self._make_room_link(room_id),
+ "avatar_url": await self._get_room_avatar(room_state_ids),
}
if not is_invite:
@@ -393,6 +394,27 @@ class Mailer:
return room_vars
+ async def _get_room_avatar(
+ self,
+ room_state_ids: StateMap[str],
+ ) -> Optional[str]:
+ """
+ Retrieve the avatar url for this room---if it exists.
+
+ Args:
+ room_state_ids: The event IDs of the current room state.
+
+ Returns:
+ room's avatar url if it's present and a string; otherwise None.
+ """
+ event_id = room_state_ids.get((EventTypes.RoomAvatar, ""))
+ if event_id:
+ ev = await self.store.get_event(event_id)
+ url = ev.content.get("url")
+ if isinstance(url, str):
+ return url
+ return None
+
async def _get_notif_vars(
self,
notif: Dict[str, Any],
diff --git a/synapse/res/providers.json b/synapse/res/providers.json
new file mode 100644
index 0000000000..f1838f9559
--- /dev/null
+++ b/synapse/res/providers.json
@@ -0,0 +1,17 @@
+[
+ {
+ "provider_name": "Twitter",
+ "provider_url": "http://www.twitter.com/",
+ "endpoints": [
+ {
+ "schemes": [
+ "https://twitter.com/*/status/*",
+ "https://*.twitter.com/*/status/*",
+ "https://twitter.com/*/moments/*",
+ "https://*.twitter.com/*/moments/*"
+ ],
+ "url": "https://publish.twitter.com/oembed"
+ }
+ ]
+ }
+]
\ No newline at end of file
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 7517e9304e..d1badbdf3b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, account_data_type):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
return 200, {}
- async def on_GET(self, request, user_id, account_data_type):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
"/account_data/(?P<account_data_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, account_data_type):
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
return 200, {}
- async def on_GET(self, request, user_id, room_id, account_data_type):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
return 200, event
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index c3667ff8aa..0950f43f2f 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -156,7 +156,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id: str,
category_id: Optional[str],
room_id: str,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -188,7 +188,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -451,7 +451,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -674,7 +674,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -706,7 +706,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: SynapseRequest, group_id, user_id
+ self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -738,7 +738,7 @@ class GroupAdminUsersKickServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: SynapseRequest, group_id, user_id
+ self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index d9ab836cd8..9770413c61 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,13 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
"/(?P<event_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id, receipt_type, event_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 7b5f49d635..a28acd4041 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -14,7 +14,9 @@
# limitations under the License.
import logging
import random
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.server import Request
import synapse
import synapse.api.auth
@@ -29,15 +31,13 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
-from synapse.config.captcha import CaptchaConfig
-from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
@@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html
)
- async def on_GET(self, request, medium):
+ async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet):
)
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False)
- kind = b"user"
- if b"kind" in request.args:
- kind = request.args[b"kind"][0]
+ kind = parse_string(request, "kind", default="user")
- if kind == b"guest":
+ if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr)
return ret
- elif kind != b"user":
+ elif kind != "user":
raise UnrecognizedRequestError(
- "Do not understand membership kind: %s" % (kind.decode("utf8"),)
+ f"Do not understand membership kind: {kind}",
)
if self._msc2918_enabled:
@@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet):
async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False
- ):
+ ) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
)
@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
- ):
+ ) -> JsonDict:
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result
- async def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(
+ self, params: JsonDict, address: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows(
- # technically `config` has to provide *all* of these interfaces, not just one
- config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
- auth_handler: AuthHandler,
+ config: HomeServerConfig, auth_handler: AuthHandler
) -> List[List[str]]:
"""Get a suitable flows list for registration
@@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 0821cd285f..0b0711c03c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
"""
import logging
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
http_server.register_paths(
"POST",
client_patterns(self.PATTERN + "$", releases=()),
@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
self.__class__.__name__,
)
- def on_PUT(self, request, *args, **kwargs):
+ def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
- request, self.on_PUT_or_POST, request, *args, **kwargs
+ request,
+ self.on_PUT_or_POST,
+ request,
+ room_id,
+ parent_id,
+ relation_type,
+ event_type,
+ txn_id,
)
async def on_PUT_or_POST(
- self, request, room_id, parent_id, relation_type, event_type, txn_id=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ key: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return 200, return_value
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c5c54564be..9b0c546505 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,11 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
+from twisted.web.server import Request
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -30,6 +32,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
- def on_PUT(self, request, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
info, _ = await self._room_creation_handler.create_room(
@@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
return 200, info
- def get_room_config(self, request):
+ def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__,
)
- def on_GET_no_state_key(self, request, room_id, event_type):
+ def on_GET_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_GET(request, room_id, event_type, "")
- def on_PUT_no_state_key(self, request, room_id, event_type):
+ def on_PUT_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "")
- async def on_GET(self, request, room_id, event_type, state_key):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
@@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ # Format must be event or content, per the parse_string call above.
+ raise RuntimeError(f"Unknown format: {format:r}.")
+
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if txn_id:
@@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- async def on_POST(self, request, room_id, event_type, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- event_dict = {
+ event_dict: JsonDict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
+ # Twisted will have processed the args by now.
+ assert request.args is not None
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
@@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_GET(self, request, room_id, event_type, txn_id):
+ def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Tuple[int, str]:
return 200, "Not implemented"
- def on_PUT(self, request, room_id, event_type, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id]
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
- def on_PUT(self, request, room_identifier, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_identifier: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
try:
@@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
return 200, data
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server")
@@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
# TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
@@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
users_with_profile = await self.message_handler.get_joined_members(
@@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
+ # Twisted will have processed the args by now.
+ assert request.args is not None
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
@@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, List[JsonDict]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
events = await self.message_handler.get_state_events(
@@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
@@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
event = await self.event_handler.get_event(
@@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ event_dict = await self._event_serializer.serialize_event(event, time_now)
+ return 200, event_dict
- return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet):
@@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
- def on_PUT(self, request, room_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/[invite|join|leave]
PATTERNS = (
"/rooms/(?P<room_id>[^/]*)/"
@@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ membership_action: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
@@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
- def _has_3pid_invite_keys(self, content):
+ def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
- def on_PUT(self, request, room_id, membership_action, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, event_id, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_id: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_PUT(self, request, room_id, event_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
hs.config.worker.writers.typing == hs.get_instance_name()
)
- async def on_PUT(self, request, room_id, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer:
@@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
alias_list = await self.directory_handler.get_aliases_for_room(
@@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
-def register_txn_path(servlet, regex_string, http_server, with_get=False):
+def register_txn_path(
+ servlet: RestServlet,
+ regex_string: str,
+ http_server: HttpServer,
+ with_get: bool = False,
+) -> None:
"""Registers a transaction-based path.
This registers two paths:
@@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
POST regex_string
Args:
- regex_string (str): The regex string to register. Must NOT have a
- trailing $ as this string will be appended to.
- http_server : The http_server to register paths with.
+ regex_string: The regex string to register. Must NOT have a
+ trailing $ as this string will be appended to.
+ http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
+ on_POST = getattr(servlet, "on_POST", None)
+ on_PUT = getattr(servlet, "on_PUT", None)
+ if on_POST is None or on_PUT is None:
+ raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
- servlet.on_POST,
+ on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_PUT,
+ on_PUT,
servlet.__class__.__name__,
)
+ on_GET = getattr(servlet, "on_GET", None)
if with_get:
+ if on_GET is None:
+ raise RuntimeError(
+ "register_txn_path called with with_get = True, but no on_GET method exists"
+ )
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_GET,
+ on_GET,
servlet.__class__.__name__,
)
@@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
-def register_servlets(hs: "HomeServer", http_server, is_worker=False):
+def register_servlets(
+ hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
+) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomForgetRestServlet(hs).register(http_server)
-def register_deprecated_servlets(hs, http_server):
+def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomInitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
new file mode 100644
index 0000000000..afe41823e4
--- /dev/null
+++ b/synapse/rest/media/v1/oembed.py
@@ -0,0 +1,135 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import attr
+
+from synapse.http.client import SimpleHttpClient
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html: Optional[str]
+ url: Optional[str]
+ title: Optional[str]
+ # Number of seconds to cache the content.
+ cache_age: int
+
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
+
+class OEmbedProvider:
+ """
+ A helper for accessing oEmbed content.
+
+ It can be used to check if a URL should be accessed via oEmbed and for
+ requesting/parsing oEmbed content.
+ """
+
+ def __init__(self, hs: "HomeServer", client: SimpleHttpClient):
+ self._oembed_patterns = {}
+ for oembed_endpoint in hs.config.oembed.oembed_patterns:
+ for pattern in oembed_endpoint.url_patterns:
+ self._oembed_patterns[pattern] = oembed_endpoint.api_endpoint
+ self._client = client
+
+ def get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Args:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in self._oembed_patterns.items():
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Args:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+ result = await self._client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ # Note that only the JSON format is supported.
+ args={"url": url},
+ )
+
+ # Ensure there's a version of 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0f051d4041..317d333b12 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -25,8 +25,6 @@ import traceback
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
-import attr
-
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
@@ -43,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -71,63 +70,6 @@ OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
-# A map of globs to API endpoints.
-_oembed_globs = {
- # Twitter.
- "https://publish.twitter.com/oembed": [
- "https://twitter.com/*/status/*",
- "https://*.twitter.com/*/status/*",
- "https://twitter.com/*/moments/*",
- "https://*.twitter.com/*/moments/*",
- # Include the HTTP versions too.
- "http://twitter.com/*/status/*",
- "http://*.twitter.com/*/status/*",
- "http://twitter.com/*/moments/*",
- "http://*.twitter.com/*/moments/*",
- ],
-}
-# Convert the globs to regular expressions.
-_oembed_patterns = {}
-for endpoint, globs in _oembed_globs.items():
- for glob in globs:
- # Convert the glob into a sane regular expression to match against. The
- # rules followed will be slightly different for the domain portion vs.
- # the rest.
- #
- # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
- # 2. The domain can have globs, but we limit it to characters that can
- # reasonably be a domain part.
- # TODO: This does not attempt to handle Unicode domain names.
- # 3. Other parts allow a glob to be any one, or more, characters.
- results = urlparse.urlparse(glob)
-
- # Ensure the scheme does not have wildcards (and is a sane scheme).
- if results.scheme not in {"http", "https"}:
- raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
-
- pattern = urlparse.urlunparse(
- [
- results.scheme,
- re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
- ]
- + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
- )
- _oembed_patterns[re.compile(pattern)] = endpoint
-
-
-@attr.s(slots=True)
-class OEmbedResult:
- # Either HTML content or URL must be provided.
- html = attr.ib(type=Optional[str])
- url = attr.ib(type=Optional[str])
- title = attr.ib(type=Optional[str])
- # Number of seconds to cache the content.
- cache_age = attr.ib(type=int)
-
-
-class OEmbedError(Exception):
- """An error occurred processing the oEmbed object."""
-
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
@@ -157,6 +99,8 @@ class PreviewUrlResource(DirectServeJsonResource):
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
+ self._oembed = OEmbedProvider(hs, self.client)
+
# We run the background jobs if we're the instance specified (or no
# instance is specified, where we assume there is only one instance
# serving media).
@@ -367,87 +311,6 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
- def _get_oembed_url(self, url: str) -> Optional[str]:
- """
- Check whether the URL should be downloaded as oEmbed content instead.
-
- Args:
- url: The URL to check.
-
- Returns:
- A URL to use instead or None if the original URL should be used.
- """
- for url_pattern, endpoint in _oembed_patterns.items():
- if url_pattern.fullmatch(url):
- return endpoint
-
- # No match.
- return None
-
- async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
- """
- Request content from an oEmbed endpoint.
-
- Args:
- endpoint: The oEmbed API endpoint.
- url: The URL to pass to the API.
-
- Returns:
- An object representing the metadata returned.
-
- Raises:
- OEmbedError if fetching or parsing of the oEmbed information fails.
- """
- try:
- logger.debug("Trying to get oEmbed content for url '%s'", url)
- result = await self.client.get_json(
- endpoint,
- # TODO Specify max height / width.
- # Note that only the JSON format is supported.
- args={"url": url},
- )
-
- # Ensure there's a version of 1.0.
- if result.get("version") != "1.0":
- raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
- oembed_type = result.get("type")
-
- # Ensure the cache age is None or an int.
- cache_age = result.get("cache_age")
- if cache_age:
- cache_age = int(cache_age)
-
- oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
-
- # HTML content.
- if oembed_type == "rich":
- oembed_result.html = result.get("html")
- return oembed_result
-
- if oembed_type == "photo":
- oembed_result.url = result.get("url")
- return oembed_result
-
- # TODO Handle link and video types.
-
- if "thumbnail_url" in result:
- oembed_result.url = result.get("thumbnail_url")
- return oembed_result
-
- raise OEmbedError("Incompatible oEmbed information.")
-
- except OEmbedError as e:
- # Trap OEmbedErrors first so we can directly re-raise them.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- raise
-
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
- raise OEmbedError() from e
-
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -459,11 +322,11 @@ class PreviewUrlResource(DirectServeJsonResource):
# If this URL can be accessed via oEmbed, use that instead.
url_to_download: Optional[str] = url
- oembed_url = self._get_oembed_url(url)
+ oembed_url = self._oembed.get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
try:
- oembed_result = await self._get_oembed_content(oembed_url, url)
+ oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
if oembed_result.url:
url_to_download = oembed_result.url
elif oembed_result.html:
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index c4ba13a6b2..fa8018e5a7 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import email.message
import os
+from typing import Dict, List, Sequence, Tuple
import attr
import pkg_resources
@@ -70,9 +71,10 @@ class EmailPusherTests(HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
# List[Tuple[Deferred, args, kwargs]]
- self.email_attempts = []
+ self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = []
def sendmail(*args, **kwargs):
+ # This mocks out synapse.reactor.send_email._sendmail.
d = Deferred()
self.email_attempts.append((d, args, kwargs))
return d
@@ -255,6 +257,39 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_room_notifications_include_avatar(self):
+ # Create a room and set its avatar.
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.send_state(
+ room, "m.room.avatar", {"url": "mxc://DUMMY_MEDIA_ID"}, self.access_token
+ )
+
+ # Invite two other uses.
+ for other in self.others:
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=room, user=other.id, tok=other.token)
+
+ # The other users send some messages.
+ # TODO It seems that two messages are required to trigger an email?
+ self.helper.send(room, body="Alpha", tok=self.others[0].token)
+ self.helper.send(room, body="Beta", tok=self.others[1].token)
+
+ # We should get emailed about those messages
+ args, kwargs = self._check_for_mail()
+
+ # That email should contain the room's avatar
+ msg: bytes = args[5]
+ # Multipart: plain text, base 64 encoded; html, base 64 encoded
+ html = (
+ email.message_from_bytes(msg)
+ .get_payload()[1]
+ .get_payload(decode=True)
+ .decode()
+ )
+ self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
+
def test_empty_room(self):
"""All users leaving a room shouldn't cause the pusher to break."""
# Create a simple room with two users
@@ -388,9 +423,14 @@ class EmailPusherTests(HomeserverTestCase):
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
- def _check_for_mail(self):
- """Check that the user receives an email notification"""
+ def _check_for_mail(self) -> Tuple[Sequence, Dict]:
+ """
+ Assert that synapse sent off exactly one email notification.
+ Returns:
+ args and kwargs passed to synapse.reactor.send_email._sendmail for
+ that notification.
+ """
# Get the stream ordering before it gets sent
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
@@ -413,8 +453,9 @@ class EmailPusherTests(HomeserverTestCase):
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
+ deferred, sendmail_args, sendmail_kwargs = self.email_attempts[0]
# Make the email succeed
- self.email_attempts[0][0].callback(True)
+ deferred.callback(True)
self.pump()
# One email was attempted to be sent
@@ -430,3 +471,4 @@ class EmailPusherTests(HomeserverTestCase):
# Reset the attempts.
self.email_attempts = []
+ return sendmail_args, sendmail_kwargs
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index d3ef7bb4c6..7fa9027227 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -14,13 +14,14 @@
import json
import os
import re
-from unittest.mock import patch
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
+from synapse.config.oembed import OEmbedEndpointConfig
+
from tests import unittest
from tests.server import FakeTransport
@@ -81,6 +82,19 @@ class URLPreviewTests(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
+ # After the hs is created, modify the parsed oEmbed config (to avoid
+ # messing with files).
+ #
+ # Note that HTTP URLs are used to avoid having to deal with TLS in tests.
+ hs.config.oembed.oembed_patterns = [
+ OEmbedEndpointConfig(
+ api_endpoint="http://publish.twitter.com/oembed",
+ url_patterns=[
+ re.compile(r"http://twitter\.com/.+/status/.+"),
+ ],
+ )
+ ]
+
return hs
def prepare(self, reactor, clock, hs):
@@ -544,123 +558,101 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def test_oembed_photo(self):
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
- # Route the HTTP version to an HTTP endpoint so that the tests work.
- with patch.dict(
- "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
- {
- re.compile(
- r"http://twitter\.com/.+/status/.+"
- ): "http://publish.twitter.com/oembed",
- },
- clear=True,
- ):
-
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "photo",
- "url": "http://cdn.twitter.com/matrixdotorg",
- }
- oembed_content = json.dumps(result).encode("utf-8")
-
- end_content = (
- b"<html><head>"
- b"<title>Some Title</title>"
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- channel = self.make_request(
- "GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(oembed_content),)
- + oembed_content
- )
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
- self.pump()
-
- client = self.reactor.tcpClients[1][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
)
+ % (len(oembed_content),)
+ + oembed_content
+ )
- self.pump()
+ self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
- # Route the HTTP version to an HTTP endpoint so that the tests work.
- with patch.dict(
- "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
- {
- re.compile(
- r"http://twitter\.com/.+/status/.+"
- ): "http://publish.twitter.com/oembed",
- },
- clear=True,
- ):
-
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "rich",
- "html": "<div>Content Preview</div>",
- }
- end_content = json.dumps(result).encode("utf-8")
-
- channel = self.make_request(
- "GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
)
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
|