diff --git a/changelog.d/9726.bugfix b/changelog.d/9726.bugfix
new file mode 100644
index 0000000000..4ba0b24327
--- /dev/null
+++ b/changelog.d/9726.bugfix
@@ -0,0 +1 @@
+Fixes the OIDC SSO flow when using a `public_baseurl` value including a non-root URL path.
\ No newline at end of file
diff --git a/changelog.d/9817.misc b/changelog.d/9817.misc
new file mode 100644
index 0000000000..8aa8895f05
--- /dev/null
+++ b/changelog.d/9817.misc
@@ -0,0 +1 @@
+Fix a long-standing bug which caused `max_upload_size` to not be correctly enforced.
diff --git a/changelog.d/9874.misc b/changelog.d/9874.misc
new file mode 100644
index 0000000000..ba1097e65e
--- /dev/null
+++ b/changelog.d/9874.misc
@@ -0,0 +1 @@
+Pass a reactor into `SynapseSite` to make testing easier.
diff --git a/changelog.d/9876.misc b/changelog.d/9876.misc
new file mode 100644
index 0000000000..28390e32e6
--- /dev/null
+++ b/changelog.d/9876.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.
diff --git a/changelog.d/9878.misc b/changelog.d/9878.misc
new file mode 100644
index 0000000000..927876852d
--- /dev/null
+++ b/changelog.d/9878.misc
@@ -0,0 +1 @@
+Remove redundant `_PushHTTPChannel` test class.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2d845d0d5c..efc926d094 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
from twisted.web.server import Request
-import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import StateMap, UserID
+from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -68,7 +70,7 @@ class Auth:
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -88,13 +90,13 @@ class Auth:
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
- ):
+ ) -> None:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_by_id = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
@@ -151,17 +153,11 @@ class Auth:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- async def check_host_in_room(self, room_id, host):
+ async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = await self.store.is_host_joined(room_id, host)
- return latest_event_ids
-
- def can_federate(self, event, auth_events):
- creation_event = auth_events.get((EventTypes.Create, ""))
+ return await self.store.is_host_joined(room_id, host)
- return creation_event.content.get("m.federate", True) is True
-
- def get_public_keys(self, invite_event):
+ def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
return event_auth.get_public_keys(invite_event)
async def get_user_by_req(
@@ -170,7 +166,7 @@ class Auth:
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
- ) -> synapse.types.Requester:
+ ) -> Requester:
"""Get a registered user's ID.
Args:
@@ -196,7 +192,7 @@ class Auth:
access_token = self.get_access_token_from_request(request)
user_id, app_service = await self._get_appservice_user_id(request)
- if user_id:
+ if user_id and app_service:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@@ -206,9 +202,7 @@ class Auth:
device_id="dummy-device", # stubbed
)
- requester = synapse.types.create_requester(
- user_id, app_service=app_service
- )
+ requester = create_requester(user_id, app_service=app_service)
request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
@@ -251,7 +245,7 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
- requester = synapse.types.create_requester(
+ requester = create_requester(
user_info.user_id,
token_id,
is_guest,
@@ -271,7 +265,9 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
- async def _get_appservice_user_id(self, request):
+ async def _get_appservice_user_id(
+ self, request: Request
+ ) -> Tuple[Optional[str], Optional[ApplicationService]]:
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@@ -283,6 +279,9 @@ class Auth:
if ip_address not in app_service.ip_range_whitelist:
return None, None
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
if b"user_id" not in request.args:
return app_service.sender, app_service
@@ -387,7 +386,9 @@ class Auth:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise InvalidClientTokenError("Invalid macaroon passed.")
- def _parse_and_validate_macaroon(self, token, rights="access"):
+ def _parse_and_validate_macaroon(
+ self, token: str, rights: str = "access"
+ ) -> Tuple[str, bool]:
"""Takes a macaroon and tries to parse and validate it. This is cached
if and only if rights == access and there isn't an expiry.
@@ -432,15 +433,16 @@ class Auth:
return user_id, guest
- def validate_macaroon(self, macaroon, type_string, user_id):
+ def validate_macaroon(
+ self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
+ ) -> None:
"""
validate that a Macaroon is understood by and was signed by this server.
Args:
- macaroon(pymacaroons.Macaroon): The macaroon to validate
- type_string(str): The kind of token required (e.g. "access",
- "delete_pusher")
- user_id (str): The user_id required
+ macaroon: The macaroon to validate
+ type_string: The kind of token required (e.g. "access", "delete_pusher")
+ user_id: The user_id required
"""
v = pymacaroons.Verifier()
@@ -465,9 +467,7 @@ class Auth:
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
- request.requester = synapse.types.create_requester(
- service.sender, app_service=service
- )
+ request.requester = create_requester(service.sender, app_service=service)
return service
async def is_server_admin(self, user: UserID) -> bool:
@@ -519,7 +519,7 @@ class Auth:
return auth_ids
- async def check_can_change_room_list(self, room_id: str, user: UserID):
+ async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
@@ -554,11 +554,11 @@ class Auth:
return user_level >= send_level
@staticmethod
- def has_access_token(request: Request):
+ def has_access_token(request: Request) -> bool:
"""Checks if the request has an access_token.
Returns:
- bool: False if no access_token was given, True otherwise.
+ False if no access_token was given, True otherwise.
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None
@@ -568,13 +568,13 @@ class Auth:
return bool(query_params) or bool(auth_headers)
@staticmethod
- def get_access_token_from_request(request: Request):
+ def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request.
Args:
request: The http request.
Returns:
- unicode: The access_token
+ The access_token
Raises:
MissingClientTokenError: If there isn't a single access_token in the
request
@@ -649,5 +649,5 @@ class Auth:
% (user_id, room_id),
)
- def check_auth_blocking(self, *args, **kwargs):
- return self._auth_blocking.check_auth_blocking(*args, **kwargs)
+ async def check_auth_blocking(self, *args, **kwargs) -> None:
+ await self._auth_blocking.check_auth_blocking(*args, **kwargs)
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index a8df60cb89..e6bced93d5 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -13,18 +13,21 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import Requester
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AuthBlocking:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid
@@ -43,7 +46,7 @@ class AuthBlocking:
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
- ):
+ ) -> None:
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 31a59bceec..936b6534b4 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,9 @@
"""Contains constants from the specification."""
+# the max size of a (canonical-json-encoded) event
+MAX_PDU_SIZE = 65536
+
# the "depth" field on events is limited to 2**63 - 1
MAX_DEPTH = 2 ** 63 - 1
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 2113c4f370..638e01c1b2 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -30,9 +30,10 @@ from twisted.internet import defer, error, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
+from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
-from synapse.config.server import ListenerConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -288,7 +289,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer"):
"""
Start a Synapse server or worker.
@@ -300,7 +301,6 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
Args:
hs: homeserver instance
- listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
@@ -336,7 +336,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
# It is now safe to start your Synapse.
- hs.start_listening(listeners)
+ hs.start_listening()
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
@@ -530,3 +530,25 @@ def sdnotify(state):
# this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET
# unless systemd is expecting us to notify it.
logger.warning("Unable to send notification to systemd: %s", e)
+
+
+def max_request_body_size(config: HomeServerConfig) -> int:
+ """Get a suitable maximum size for incoming HTTP requests"""
+
+ # Other than media uploads, the biggest request we expect to see is a fully-loaded
+ # /federation/v1/send request.
+ #
+ # The main thing in such a request is up to 50 PDUs, and up to 100 EDUs. PDUs are
+ # limited to 65536 bytes (possibly slightly more if the sender didn't use canonical
+ # json encoding); there is no specced limit to EDUs (see
+ # https://github.com/matrix-org/matrix-doc/issues/3121).
+ #
+ # in short, we somewhat arbitrarily limit requests to 200 * 64K (about 12.5M)
+ #
+ max_request_size = 200 * MAX_PDU_SIZE
+
+ # if we have a media repo enabled, we may need to allow larger uploads than that
+ if config.media.can_load_media_repo:
+ max_request_size = max(max_request_size, config.media.max_upload_size)
+
+ return max_request_size
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index eb256db749..68ae19c977 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -70,12 +70,6 @@ class AdminCmdSlavedStore(
class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore
- def _listen_http(self, listener_config):
- pass
-
- def start_listening(self, listeners):
- pass
-
async def export_data_command(hs, args):
"""Export data for a user.
@@ -232,7 +226,7 @@ def start(config_options):
async def run():
with LoggingContext("command"):
- _base.start(ss, [])
+ _base.start(ss)
await args.func(ss, args)
_base.start_worker_reactor(
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index d831f793b9..f3ae7ac69a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
import sys
-from typing import Dict, Iterable, Optional
+from typing import Dict, Optional
from twisted.internet import address
from twisted.web.resource import IResource
@@ -32,7 +32,7 @@ from synapse.api.urls import (
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
-from synapse.app._base import register_start
+from synapse.app._base import max_request_body_size, register_start
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
@@ -367,14 +367,16 @@ class GenericWorkerServer(HomeServer):
listener_config,
root_resource,
self.version_string,
+ max_request_body_size=max_request_body_size(self.config),
+ reactor=self.get_reactor(),
),
reactor=self.get_reactor(),
)
logger.info("Synapse worker now listening on port %d", port)
- def start_listening(self, listeners: Iterable[ListenerConfig]):
- for listener in listeners:
+ def start_listening(self):
+ for listener in self.config.worker_listeners:
if listener.type == "http":
self._listen_http(listener)
elif listener.type == "manhole":
@@ -468,7 +470,7 @@ def start(config_options):
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
- register_start(_base.start, hs, config.worker_listeners)
+ register_start(_base.start, hs)
_base.start_worker_reactor("synapse-generic-worker", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index fd7958cecd..5bd439a3ba 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -17,7 +17,7 @@
import logging
import os
import sys
-from typing import Iterable, Iterator
+from typing import Iterator
from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper, IResource
@@ -36,7 +36,13 @@ from synapse.api.urls import (
WEB_CLIENT_PREFIX,
)
from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
+from synapse.app._base import (
+ listen_ssl,
+ listen_tcp,
+ max_request_body_size,
+ quit_with_error,
+ register_start,
+)
from synapse.config._base import ConfigError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
@@ -126,19 +132,21 @@ class SynapseHomeServer(HomeServer):
else:
root_resource = OptionsResource()
- root_resource = create_resource_tree(resources, root_resource)
+ site = SynapseSite(
+ "synapse.access.%s.%s" % ("https" if tls else "http", site_tag),
+ site_tag,
+ listener_config,
+ create_resource_tree(resources, root_resource),
+ self.version_string,
+ max_request_body_size=max_request_body_size(self.config),
+ reactor=self.get_reactor(),
+ )
if tls:
ports = listen_ssl(
bind_addresses,
port,
- SynapseSite(
- "synapse.access.https.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
+ site,
self.tls_server_context_factory,
reactor=self.get_reactor(),
)
@@ -148,13 +156,7 @@ class SynapseHomeServer(HomeServer):
ports = listen_tcp(
bind_addresses,
port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
+ site,
reactor=self.get_reactor(),
)
logger.info("Synapse now listening on TCP port %d", port)
@@ -273,14 +275,14 @@ class SynapseHomeServer(HomeServer):
return resources
- def start_listening(self, listeners: Iterable[ListenerConfig]):
+ def start_listening(self):
if self.config.redis_enabled:
# If redis is enabled we connect via the replication command handler
# in the same way as the workers (since we're effectively a client
# rather than a server).
self.get_tcp_replication().start_replication(self)
- for listener in listeners:
+ for listener in self.config.server.listeners:
if listener.type == "http":
self._listening_services.extend(
self._listener_http(self.config, listener)
@@ -413,7 +415,7 @@ def setup(config_options):
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
- await _base.start(hs, config.listeners)
+ await _base.start(hs)
hs.get_datastore().db_pool.updates.start_doing_background_updates()
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index b174e0df6d..813076dfe2 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -31,7 +31,6 @@ from twisted.logger import (
)
import synapse
-from synapse.app import _base as appbase
from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
@@ -318,6 +317,8 @@ def setup_logging(
# Perform one-time logging configuration.
_setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
# Add a SIGHUP handler to reload the logging configuration, if one is available.
+ from synapse.app import _base as appbase
+
appbase.register_sighup(_reload_logging_config, log_config_path)
# Log immediately so we can grep backwards.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 02b86b11a5..21ca7b33e3 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -235,7 +235,11 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
+
self.public_baseurl = config.get("public_baseurl")
+ if self.public_baseurl is not None:
+ if self.public_baseurl[-1] != "/":
+ self.public_baseurl += "/"
# Whether to enable user presence.
presence_config = config.get("presence") or {}
@@ -407,10 +411,6 @@ class ServerConfig(Config):
config_path=("federation_ip_range_blacklist",),
)
- if self.public_baseurl is not None:
- if self.public_baseurl[-1] != "/":
- self.public_baseurl += "/"
-
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before
# sending out any replication updates.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c831d9f73c..70c556566e 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,14 +14,14 @@
# limitations under the License.
import logging
-from typing import List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import MAX_PDU_SIZE, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@@ -205,7 +205,7 @@ def _check_size_limits(event: EventBase) -> None:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
- if len(encode_canonical_json(event.get_pdu_json())) > 65536:
+ if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
too_big("event")
@@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
return False
-def get_public_keys(invite_event):
+def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 1c4a43be0a..ee6e41c0e4 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -15,7 +15,7 @@
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
-from urllib.parse import urlencode
+from urllib.parse import urlencode, urlparse
import attr
import pymacaroons
@@ -68,8 +68,8 @@ logger = logging.getLogger(__name__)
#
# Here we have the names of the cookies, and the options we use to set them.
_SESSION_COOKIES = [
- (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
- (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+ (b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
+ (b"oidc_session_no_samesite", b"HttpOnly"),
]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
@@ -279,6 +279,13 @@ class OidcProvider:
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
+ # Calculate the prefix for OIDC callback paths based on the public_baseurl.
+ # We'll insert this into the Path= parameter of any session cookies we set.
+ public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
+ self._callback_path_prefix = (
+ public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
+ )
+
self._oidc_attribute_requirements = provider.attribute_requirements
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
@@ -779,8 +786,13 @@ class OidcProvider:
for cookie_name, options in _SESSION_COOKIES:
request.cookies.append(
- b"%s=%s; Max-Age=3600; %s"
- % (cookie_name, cookie.encode("utf-8"), options)
+ b"%s=%s; Max-Age=3600; Path=%s; %s"
+ % (
+ cookie_name,
+ cookie.encode("utf-8"),
+ self._callback_path_prefix,
+ options,
+ )
)
metadata = await self.load_metadata()
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 32b5e19c09..671fd3fbcc 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,14 @@
import contextlib
import logging
import time
-from typing import Optional, Tuple, Type, Union
+from typing import Optional, Tuple, Union
import attr
from zope.interface import implementer
-from twisted.internet.interfaces import IAddress
+from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
+from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@@ -49,6 +50,7 @@ class SynapseRequest(Request):
* Redaction of access_token query-params in __repr__
* Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint.
+ * A limit to the size of request which will be accepted
It also provides a method `processing`, which returns a context manager. If this
method is called, the request won't be logged until the context manager is closed;
@@ -59,8 +61,9 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
- def __init__(self, channel, *args, **kw):
+ def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw)
+ self._max_request_body_size = max_request_body_size
self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -97,6 +100,18 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ def handleContentChunk(self, data):
+ # we should have a `content` by now.
+ assert self.content, "handleContentChunk() called before gotLength()"
+ if self.content.tell() + len(data) > self._max_request_body_size:
+ logger.warning(
+ "Aborting connection from %s because the request exceeds maximum size",
+ self.client,
+ )
+ self.transport.abortConnection()
+ return
+ super().handleContentChunk(data)
+
@property
def requester(self) -> Optional[Union[Requester, str]]:
return self._requester
@@ -485,29 +500,55 @@ class _XForwardedForAddress:
class SynapseSite(Site):
"""
- Subclass of a twisted http Site that does access logging with python's
- standard logging
+ Synapse-specific twisted http Site
+
+ This does two main things.
+
+ First, it replaces the requestFactory in use so that we build SynapseRequests
+ instead of regular t.w.server.Requests. All of the constructor params are really
+ just parameters for SynapseRequest.
+
+ Second, it inhibits the log() method called by Request.finish, since SynapseRequest
+ does its own logging.
"""
def __init__(
self,
- logger_name,
- site_tag,
+ logger_name: str,
+ site_tag: str,
config: ListenerConfig,
- resource,
+ resource: IResource,
server_version_string,
- *args,
- **kwargs,
+ max_request_body_size: int,
+ reactor: IReactorTime,
):
- Site.__init__(self, resource, *args, **kwargs)
+ """
+
+ Args:
+ logger_name: The name of the logger to use for access logs.
+ site_tag: A tag to use for this site - mostly in access logs.
+ config: Configuration for the HTTP listener corresponding to this site
+ resource: The base of the resource tree to be used for serving requests on
+ this site
+ server_version_string: A string to present for the Server header
+ max_request_body_size: Maximum request body length to allow before
+ dropping the connection
+ reactor: reactor to be used to manage connection timeouts
+ """
+ Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
assert config.http_options is not None
proxied = config.http_options.x_forwarded
- self.requestFactory = (
- XForwardedForRequest if proxied else SynapseRequest
- ) # type: Type[Request]
+ request_class = XForwardedForRequest if proxied else SynapseRequest
+
+ def request_factory(channel, queued) -> Request:
+ return request_class(
+ channel, max_request_body_size=max_request_body_size, queued=queued
+ )
+
+ self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 80f017a4dd..024a105bf2 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -51,8 +51,6 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
- # TODO: The checks here are a bit late. The content will have
- # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
diff --git a/synapse/server.py b/synapse/server.py
index 8c147be2b3..06570bb1ce 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -287,6 +287,14 @@ class HomeServer(metaclass=abc.ABCMeta):
if self.config.run_background_tasks:
self.setup_background_tasks()
+ def start_listening(self) -> None:
+ """Start the HTTP, manhole, metrics, etc listeners
+
+ Does nothing in this base class; overridden in derived classes to start the
+ appropriate listeners.
+ """
+ pass
+
def setup_background_tasks(self) -> None:
"""
Some handlers have side effects on instantiation (like registering
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 5c8ef444fa..0f0a74baca 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -17,8 +17,10 @@ from functools import wraps
from typing import (
Any,
Callable,
+ Collection,
Generic,
Iterable,
+ List,
Optional,
Type,
TypeVar,
@@ -83,15 +85,30 @@ class _Node:
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
def __init__(
- self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+ self,
+ prev_node,
+ next_node,
+ key,
+ value,
+ callbacks: Collection[Callable[[], None]] = (),
):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
- self.callbacks = callbacks or set()
-
self.memory = 0
+
+ # Set of callbacks to run when the node gets deleted. We store as a list
+ # rather than a set to keep memory usage down (and since we expect few
+ # entries per node the performance of checking for duplication in a list
+ # vs using a set is negligible).
+ #
+ # Note that we store this as an optional list to keep the memory
+ # footprint down. Empty lists are 56 bytes (and empty sets are 216 bytes).
+ self.callbacks = None # type: Optional[List[Callable[[], None]]]
+
+ self.add_callbacks(callbacks)
+
if TRACK_MEMORY_USAGE:
self.memory = (
_get_size_of(key)
@@ -101,6 +118,32 @@ class _Node:
)
self.memory += _get_size_of(self.memory, recurse=False)
+ def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
+ """Add to stored list of callbacks, removing duplicates."""
+
+ if not callbacks:
+ return
+
+ if not self.callbacks:
+ self.callbacks = []
+
+ for callback in callbacks:
+ if callback not in self.callbacks:
+ self.callbacks.append(callback)
+
+ def run_and_clear_callbacks(self) -> None:
+ """Run all callbacks and clear the stored set of callbacks. Used when
+ the node is being deleted.
+ """
+
+ if not self.callbacks:
+ return
+
+ for callback in self.callbacks:
+ callback()
+
+ self.callbacks = None
+
class LruCache(Generic[KT, VT]):
"""
@@ -213,10 +256,10 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
- def add_node(key, value, callbacks: Optional[set] = None):
+ def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
prev_node = list_root
next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value, callbacks or set())
+ node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@@ -250,9 +293,7 @@ class LruCache(Generic[KT, VT]):
deleted_len = size_callback(node.value)
cached_cache_len[0] -= deleted_len
- for cb in node.callbacks:
- cb()
- node.callbacks.clear()
+ node.run_and_clear_callbacks()
if TRACK_MEMORY_USAGE and metrics:
metrics.dec_memory_usage(node.memory)
@@ -263,7 +304,7 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: Literal[None] = None,
- callbacks: Iterable[Callable[[], None]] = ...,
+ callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Optional[VT]:
...
@@ -272,7 +313,7 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: T,
- callbacks: Iterable[Callable[[], None]] = ...,
+ callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...
@@ -281,13 +322,13 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: Optional[T] = None,
- callbacks: Iterable[Callable[[], None]] = (),
+ callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
- node.callbacks.update(callbacks)
+ node.add_callbacks(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
return node.value
@@ -303,10 +344,8 @@ class LruCache(Generic[KT, VT]):
# We sometimes store large objects, e.g. dicts, which cause
# the inequality check to take a long time. So let's only do
# the check if we have some callbacks to call.
- if node.callbacks and value != node.value:
- for cb in node.callbacks:
- cb()
- node.callbacks.clear()
+ if value != node.value:
+ node.run_and_clear_callbacks()
# We don't bother to protect this by value != node.value as
# generally size_callback will be cheap compared with equality
@@ -316,7 +355,7 @@ class LruCache(Generic[KT, VT]):
cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
- node.callbacks.update(callbacks)
+ node.add_callbacks(callbacks)
move_node_to_front(node)
node.value = value
@@ -369,8 +408,7 @@ class LruCache(Generic[KT, VT]):
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
- for cb in node.callbacks:
- cb()
+ node.run_and_clear_callbacks()
cache.clear()
if size_callback:
cached_cache_len[0] = 0
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
new file mode 100644
index 0000000000..8c13b4f693
--- /dev/null
+++ b/tests/http/test_site.py
@@ -0,0 +1,83 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet.address import IPv6Address
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.app.homeserver import SynapseHomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SynapseRequestTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
+
+ def test_large_request(self):
+ """overlarge HTTP requests should be rejected"""
+ self.hs.start_listening()
+
+ # find the HTTP server which is configured to listen on port 0
+ (port, factory, _backlog, interface) = self.reactor.tcpServers[0]
+ self.assertEqual(interface, "::")
+ self.assertEqual(port, 0)
+
+ # as a control case, first send a regular request.
+
+ # complete the connection and wire it up to a fake transport
+ client_address = IPv6Address("TCP", "::1", "2345")
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+
+ while not transport.disconnecting:
+ self.reactor.advance(1)
+
+ # we should get a 404
+ self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
+
+ # now send an oversized request
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ )
+
+ # we deliberately send all the data in one big chunk, to ensure that
+ # twisted isn't buffering the data in the chunked transfer decoder.
+ # we start with the chunk size, in hex. (We won't actually send this much)
+ protocol.dataReceived(b"10000000\r\n")
+ sent = 0
+ while not transport.disconnected:
+ self.assertLess(sent, 0x10000000, "connection did not drop")
+ protocol.dataReceived(b"\0" * 1024)
+ sent += 1024
+
+ # default max upload size is 50M, so it should drop on the next buffer after
+ # that.
+ self.assertEqual(sent, 50 * 1024 * 1024 + 1024)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c9d04aef29..624bd1b927 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+ channel = self.site.buildProtocol(None)
+
+ # hook into the channel's request factory so that we can keep a record
+ # of the requests
+ requests: List[SynapseRequest] = []
+ real_request_factory = channel.requestFactory
+
+ def request_factory(*args, **kwargs):
+ request = real_request_factory(*args, **kwargs)
+ requests.append(request)
+ return request
+
+ channel.requestFactory = request_factory
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return channel.request
+ # there should have been exactly one request
+ self.assertEqual(len(requests), 1)
+
+ return requests[0]
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
+ max_request_body_size=4096,
+ reactor=self.reactor,
)
if worker_hs.config.redis.redis_enabled:
@@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+ channel = self._hs_to_site[hs].buildProtocol(None)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
self.received_rdata_rows.append((stream_name, token, r))
-class _PushHTTPChannel(HTTPChannel):
- """A HTTPChannel that wraps pull producers to push producers.
-
- This is a hack to get around the fact that HTTPChannel transparently wraps a
- pull producer (which is what Synapse uses to reply to requests) with
- `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
- uses the standard reactor rather than letting us use our test reactor, which
- makes it very hard to test.
- """
-
- def __init__(
- self, reactor: IReactorTime, request_factory: Type[Request], site: Site
- ):
- super().__init__()
- self.reactor = reactor
- self.requestFactory = request_factory
- self.site = site
-
- self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
-
- def registerProducer(self, producer, streaming):
- # Convert pull producers to push producer.
- if not streaming:
- self._pull_to_push_producer = _PullToPushProducer(
- self.reactor, producer, self
- )
- producer = self._pull_to_push_producer
-
- super().registerProducer(producer, True)
-
- def unregisterProducer(self):
- if self._pull_to_push_producer:
- # We need to manually stop the _PullToPushProducer.
- self._pull_to_push_producer.stop()
-
- def checkPersistence(self, request, version):
- """Check whether the connection can be re-used"""
- # We hijack this to always say no for ease of wiring stuff up in
- # `handle_http_replication_attempt`.
- request.responseHeaders.setRawHeaders(b"connection", [b"close"])
- return False
-
- def requestDone(self, request):
- # Store the request for inspection.
- self.request = request
- super().requestDone(request)
-
-
-class _PullToPushProducer:
- """A push producer that wraps a pull producer."""
-
- def __init__(
- self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
- ):
- self._clock = Clock(reactor)
- self._producer = producer
- self._consumer = consumer
-
- # While running we use a looping call with a zero delay to call
- # resumeProducing on given producer.
- self._looping_call = None # type: Optional[LoopingCall]
-
- # We start writing next reactor tick.
- self._start_loop()
-
- def _start_loop(self):
- """Start the looping call to"""
-
- if not self._looping_call:
- # Start a looping call which runs every tick.
- self._looping_call = self._clock.looping_call(self._run_once, 0)
-
- def stop(self):
- """Stops calling resumeProducing."""
- if self._looping_call:
- self._looping_call.stop()
- self._looping_call = None
-
- def pauseProducing(self):
- """Implements IPushProducer"""
- self.stop()
-
- def resumeProducing(self):
- """Implements IPushProducer"""
- self._start_loop()
-
- def stopProducing(self):
- """Implements IPushProducer"""
- self.stop()
- self._producer.stopProducing()
-
- def _run_once(self):
- """Calls resumeProducing on producer once."""
-
- try:
- self._producer.resumeProducing()
- except Exception:
- logger.exception("Failed to call resumeProducing")
- try:
- self._consumer.unregisterProducer()
- except Exception:
- pass
-
- self.stopProducing()
-
-
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
diff --git a/tests/server.py b/tests/server.py
index b535a5d886..9df8cda24f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -603,12 +603,6 @@ class FakeTransport:
if self.disconnected:
return
- if not hasattr(self.other, "transport"):
- # the other has no transport yet; reschedule
- if self.autoflush:
- self._reactor.callLater(0.0, self.flush)
- return
-
if maxbytes is not None:
to_write = self.buffer[:maxbytes]
else:
diff --git a/tests/test_server.py b/tests/test_server.py
index 55cde7f62f..407e172e41 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
parse_listener_def({"type": "http", "port": 0}),
self.resource,
"1.0",
+ max_request_body_size=1234,
+ reactor=self.reactor,
)
# render the request and return the channel
diff --git a/tests/unittest.py b/tests/unittest.py
index ee22a53849..9bd02bd9c4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -247,6 +247,8 @@ class HomeserverTestCase(TestCase):
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
+ max_request_body_size=1234,
+ reactor=self.reactor,
)
from tests.rest.client.v1.utils import RestHelper
|