diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 2f97e6d258..c7877b4095 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -40,7 +40,7 @@ class CasConfig(Config):
self.cas_required_attributes = {}
def generate_config_section(self, config_dir_path, server_name, **kwargs):
- return """
+ return """\
# Enable Central Authentication Service (CAS) for registration and login.
#
cas_config:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index fddca19223..c7fa749377 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
# limitations under the License.
import string
-from typing import Optional, Type
+from typing import Iterable, Optional, Type
import attr
@@ -33,16 +33,8 @@ class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
- validate_config(MAIN_CONFIG_SCHEMA, config, ())
-
- self.oidc_provider = None # type: Optional[OidcProviderConfig]
-
- oidc_config = config.get("oidc_config")
- if oidc_config and oidc_config.get("enabled", False):
- validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
- self.oidc_provider = _parse_oidc_config_dict(oidc_config)
-
- if not self.oidc_provider:
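+ # parse the list of configured providers, including any legacy
+ # 'oidc_config' block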
+ self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+ if not self.oidc_providers:
return
try:
@@ -58,144 +50,153 @@ class OIDCConfig(Config):
@property
def oidc_enabled(self) -> bool:
# OIDC is enabled if we have a provider
- return bool(self.oidc_provider)
+ return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
- # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+ # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+ # and login.
+ #
+ # Options for each entry include:
+ #
+ # idp_id: a unique identifier for this identity provider. Used internally
+ # by Synapse; should be a single word such as 'github'.
+ #
+ # Note that, if this is changed, users authenticating via that provider
+ # will no longer be recognised as the same user!
+ #
+ # idp_name: A user-facing name for this identity provider, which is used to
+ # offer the user a choice of login mechanisms.
+ #
+ # discover: set to 'false' to disable the use of the OIDC discovery mechanism
+ # to discover endpoints. Defaults to true.
+ #
+ # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+ # is enabled) to discover the provider's endpoints.
+ #
+ # client_id: Required. oauth2 client id to use.
+ #
+ # client_secret: Required. oauth2 client secret to use.
+ #
+ # client_auth_method: auth method to use when exchanging the token. Valid
+ # values are 'client_secret_basic' (default), 'client_secret_post' and
+ # 'none'.
+ #
+ # scopes: list of scopes to request. This should normally include the "openid"
+ # scope. Defaults to ["openid"].
+ #
+ # authorization_endpoint: the oauth2 authorization endpoint. Required if
+ # provider discovery is disabled.
+ #
+ # token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+ # disabled.
+ #
+ # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+ # disabled and the 'openid' scope is not requested.
+ #
+ # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+ # the 'openid' scope is used.
+ #
+ # skip_verification: set to 'true' to skip metadata verification. Use this if
+ # you are connecting to a provider that is not OpenID Connect compliant.
+ # Defaults to false. Avoid this in production.
+ #
+ # user_profile_method: Whether to fetch the user profile from the userinfo
+ # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+ #
+ # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+ # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+ # userinfo endpoint.
+ #
+ # allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+ # match a pre-existing account instead of failing. This could be used if
+ # switching from password logins to OIDC. Defaults to false.
+ #
+ # user_mapping_provider: Configuration for how attributes returned from an OIDC
+ # provider are mapped onto a matrix user. This setting has the following
+ # sub-properties:
+ #
+ # module: The class name of a custom mapping module. Default is
+ # {mapping_provider!r}.
+ # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+ # for information on implementing a custom mapping provider.
+ #
+ # config: Configuration for the mapping provider module. This section will
+ # be passed as a Python dictionary to the user mapping provider
+ # module's `parse_config` method.
+ #
+ # For the default provider, the following settings are available:
+ #
+ # subject_claim: name of the claim containing a unique identifier for
+ # the user. Defaults to 'sub', which OpenID Connect compliant
+ # providers should provide.
+ #
+ # localpart_template: Jinja2 template for the localpart of the MXID.
+ # If this is not set, the user will be prompted to choose their
+ # own username.
+ #
+ # display_name_template: Jinja2 template for the display name to set
+ # on first login. If unset, no displayname will be set.
+ #
+ # extra_attributes: a map of Jinja2 templates for extra attributes
+ # to send back to the client during login.
+ # Note that these are non-standard and clients will ignore them
+ # without modifications.
+ #
+ # When rendering, the Jinja2 templates are given a 'user' variable,
+ # which is set to the claims returned by the UserInfo Endpoint and/or
+ # in the ID Token.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
- # for some example configurations.
+ # for information on how to configure these options.
#
- oidc_config:
- # Uncomment the following to enable authorization against an OpenID Connect
- # server. Defaults to false.
- #
- #enabled: true
-
- # Uncomment the following to disable use of the OIDC discovery mechanism to
- # discover endpoints. Defaults to true.
- #
- #discover: false
-
- # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
- # discover the provider's endpoints.
- #
- # Required if 'enabled' is true.
- #
- #issuer: "https://accounts.example.com/"
-
- # oauth2 client id to use.
- #
- # Required if 'enabled' is true.
- #
- #client_id: "provided-by-your-issuer"
-
- # oauth2 client secret to use.
- #
- # Required if 'enabled' is true.
- #
- #client_secret: "provided-by-your-issuer"
-
- # auth method to use when exchanging the token.
- # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
- # 'none'.
- #
- #client_auth_method: client_secret_post
-
- # list of scopes to request. This should normally include the "openid" scope.
- # Defaults to ["openid"].
- #
- #scopes: ["openid", "profile"]
-
- # the oauth2 authorization endpoint. Required if provider discovery is disabled.
- #
- #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
- # the oauth2 token endpoint. Required if provider discovery is disabled.
- #
- #token_endpoint: "https://accounts.example.com/oauth2/token"
-
- # the OIDC userinfo endpoint. Required if discovery is disabled and the
- # "openid" scope is not requested.
- #
- #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
- # URI where to fetch the JWKS. Required if discovery is disabled and the
- # "openid" scope is used.
- #
- #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
- # Uncomment to skip metadata verification. Defaults to false.
- #
- # Use this if you are connecting to a provider that is not OpenID Connect
- # compliant.
- # Avoid this in production.
- #
- #skip_verification: true
-
- # Whether to fetch the user profile from the userinfo endpoint. Valid
- # values are: "auto" or "userinfo_endpoint".
+ # For backwards compatibility, it is also possible to configure a single OIDC
+ # provider via an 'oidc_config' setting. This is now deprecated and admins are
+ # advised to migrate to the 'oidc_providers' format.
+ #
+ oidc_providers:
+ # Generic example
#
- # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
- # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+ #- idp_id: my_idp
+ # idp_name: "My OpenID provider"
+ # discover: false
+ # issuer: "https://accounts.example.com/"
+ # client_id: "provided-by-your-issuer"
+ # client_secret: "provided-by-your-issuer"
+ # client_auth_method: client_secret_post
+ # scopes: ["openid", "profile"]
+ # authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+ # token_endpoint: "https://accounts.example.com/oauth2/token"
+ # userinfo_endpoint: "https://accounts.example.com/userinfo"
+ # jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+ # skip_verification: true
+
+ # For use with Keycloak
#
- #user_profile_method: "userinfo_endpoint"
-
- # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
- # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
- #
- #allow_existing_users: true
-
- # An external module can be provided here as a custom solution to mapping
- # attributes returned from a OIDC provider onto a matrix user.
+ #- idp_id: keycloak
+ # idp_name: Keycloak
+ # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+ # client_id: "synapse"
+ # client_secret: "copy secret generated in Keycloak UI"
+ # scopes: ["openid", "profile"]
+
+ # For use with Github
#
- user_mapping_provider:
- # The custom module's class. Uncomment to use a custom module.
- # Default is {mapping_provider!r}.
- #
- # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
- # for information on implementing a custom mapping provider.
- #
- #module: mapping_provider.OidcMappingProvider
-
- # Custom configuration values for the module. This section will be passed as
- # a Python dictionary to the user mapping provider module's `parse_config`
- # method.
- #
- # The examples below are intended for the default provider: they should be
- # changed if using a custom provider.
- #
- config:
- # name of the claim containing a unique identifier for the user.
- # Defaults to `sub`, which OpenID Connect compliant providers should provide.
- #
- #subject_claim: "sub"
-
- # Jinja2 template for the localpart of the MXID.
- #
- # When rendering, this template is given the following variables:
- # * user: The claims returned by the UserInfo Endpoint and/or in the ID
- # Token
- #
- # If this is not set, the user will be prompted to choose their
- # own username.
- #
- #localpart_template: "{{{{ user.preferred_username }}}}"
-
- # Jinja2 template for the display name to set on first login.
- #
- # If unset, no displayname will be set.
- #
- #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
- # Jinja2 templates for extra attributes to send back to the client during
- # login.
- #
- # Note that these are non-standard and clients will ignore them without modifications.
- #
- #extra_attributes:
- #birthdate: "{{{{ user.birthdate }}}}"
+ #- idp_id: github
+ # idp_name: Github
+ # discover: false
+ # issuer: "https://github.com/"
+ # client_id: "your-client-id" # TO BE FILLED
+ # client_secret: "your-client-secret" # TO BE FILLED
+ # authorization_endpoint: "https://github.com/login/oauth/authorize"
+ # token_endpoint: "https://github.com/login/oauth/access_token"
+ # userinfo_endpoint: "https://api.github.com/user"
+ # scopes: ["read:user"]
+ # user_mapping_provider:
+ # config:
+ # subject_claim: "id"
+ # localpart_template: "{{ user.login }}"
+ # display_name_template: "{{ user.name }}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
@@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
},
}
-# the `oidc_config` setting can either be None (as it is in the default
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+ "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+ "oneOf": [
+ {"type": "null"},
+ {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+ ]
+}
+
+# the `oidc_config` setting can either be None (as it used to be in the default
# config), or an object. If an object, it is ignored unless it has an "enabled: True"
# property.
#
@@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
# additional checks in the code.
OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
MAIN_CONFIG_SCHEMA = {
"type": "object",
- "properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
+ "properties": {
+ "oidc_config": OIDC_CONFIG_SCHEMA,
+ "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+ },
}
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+ """extract and parse the OIDC provider configs from the config dict
+
+ The configuration may contain either a single `oidc_config` object with an
+ `enabled: True` property, or a list of provider configurations under
+ `oidc_providers`, *or both*.
+
+ Returns a generator which yields the OidcProviderConfig objects
+ """
+ validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+ for p in config.get("oidc_providers") or []:
+ yield _parse_oidc_config_dict(p)
+
+ # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+ # object with an "enabled: True" property.
+ oidc_config = config.get("oidc_config")
+ if oidc_config and oidc_config.get("enabled", False):
+ # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+ # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+ # above), so now we need to validate it.
+ validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+ yield _parse_oidc_config_dict(oidc_config)
+
+
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index cc5f75123c..740c3fc1b1 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -14,14 +14,13 @@
# limitations under the License.
import os
-from distutils.util import strtobool
import pkg_resources
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
class AccountValidityConfig(Config):
@@ -86,12 +85,12 @@ class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
- self.enable_registration = bool(
- strtobool(str(config.get("enable_registration", False)))
+ self.enable_registration = strtobool(
+ str(config.get("enable_registration", False))
)
if "disable_registration" in config:
- self.enable_registration = not bool(
- strtobool(str(config["disable_registration"]))
+ self.enable_registration = not strtobool(
+ str(config["disable_registration"])
)
self.account_validity = AccountValidityConfig(
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8028663fa8..3ec4120f85 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -17,7 +17,6 @@
import abc
import os
-from distutils.util import strtobool
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
# for the sake of tests, it is set here while it cannot be configured on the
# homeserver object itself.
+
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 109dc7932f..37a678b6ce 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -163,7 +163,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
- run_in_background(self._user_device_resync, sender_user_id)
+ run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f63a90ec5c..5e5fda7b2f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -78,21 +78,28 @@ class OidcHandler:
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
- provider_conf = hs.config.oidc.oidc_provider
+ provider_confs = hs.config.oidc.oidc_providers
# we should not have been instantiated if there is no configured provider.
- assert provider_conf is not None
+ assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
-
- self._provider = OidcProvider(hs, self._token_generator, provider_conf)
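+ # map from idp_id to an OidcProvider for that identity provider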
+ self._providers = {
+ p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+ }
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Called at startup to ensure we have everything we need.
"""
- await self._provider.load_metadata()
- await self._provider.load_jwks()
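+ # load each provider's metadata in turn, so that any failure can be
+ # attributed to the specific provider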
+ for idp_id, p in self._providers.items():
+ try:
+ await p.load_metadata()
+ await p.load_jwks()
+ except Exception as e:
+ raise Exception(
+ "Error while initialising OIDC provider %r" % (idp_id,)
+ ) from e
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
@@ -184,6 +191,12 @@ class OidcHandler:
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
+ oidc_provider = self._providers.get(session_data.idp_id)
+ if not oidc_provider:
+ logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+ self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+ return
+
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
@@ -193,7 +206,7 @@ class OidcHandler:
code = request.args[b"code"][0].decode()
- await self._provider.handle_oidc_callback(request, session_data, code)
+ await oidc_provider.handle_oidc_callback(request, session_data, code)
class OidcError(Exception):
diff --git a/synapse/http/client.py b/synapse/http/client.py
index dc4b81ca60..df498c8645 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
+ # If the deferred was called, bail early.
+ if self.deferred.called:
+ return
+
self.stream.write(data)
self.length += len(data)
+ # The first time the maximum size is exceeded, error and cancel the
+ # connection. dataReceived might be called again if data was received
+ # in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
- self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason: Failure) -> None:
+ # If the maximum size was already exceeded, there's nothing to do.
+ if self.deferred.called:
+ return
+
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id: str):
+ async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, user_id: str):
+ async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, server_name: str, media_id: str):
+ async def on_POST(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
+class ProtectMediaByID(RestServlet):
+ """Protect local media from being quarantined.
+ """
+
+ PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Protecting local media by ID: %s", media_id)
+
+ # Quarantine this media id
+ await self.store.mark_local_media_as_safe(media_id)
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_DELETE(self, request, server_name: str, media_id: str):
+ async def on_DELETE(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request, server_name: str):
+ async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
+ ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 47c2b44bff..31a41e4a27 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
import logging
import os
import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
]
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
)
-def respond_404(request):
+def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@@ -79,8 +80,12 @@ def respond_404(request):
async def respond_with_file(
- request, media_type, file_path, file_size=None, upload_name=None
-):
+ request: Request,
+ media_type: str,
+ file_path: str,
+ file_size: Optional[int] = None,
+ upload_name: Optional[str] = None,
+) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request)
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+ request: Request,
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str],
+) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
- request (twisted.web.http.Request)
- media_type (str): The media/content type.
- file_size (int): Size in bytes of the media, if known.
- upload_name (str): The name of the requested file, if any.
+ request
+ media_type: The media/content type.
+ file_size: Size in bytes of the media, if known.
+ upload_name: The name of the requested file, if any.
"""
def _quote(x):
@@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
- request.setHeader(b"Content-Length", b"%d" % (file_size,))
+ if file_size is not None:
+ request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This
# should help to prevent things in the media repo from showing up in web
@@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
}
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder(
- request, responder, media_type, file_size, upload_name=None
-):
+ request: Request,
+ responder: "Optional[Responder]",
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str] = None,
+) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
- request (twisted.web.http.Request)
- responder (Responder|None)
- media_type (str): The media/content type.
- file_size (int|None): Size in bytes of the media. If not known it should be None
- upload_name (str|None): The name of the requested file, if any.
+ request
+ responder
+ media_type: The media/content type.
+ file_size: Size in bytes of the media. If not known it should be None
+ upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@@ -308,22 +323,22 @@ class FileInfo:
self.thumbnail_type = thumbnail_type
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
- headers (dict[bytes, list[bytes]]): The HTTP request headers.
+ headers: The HTTP request headers.
Returns:
- A Unicode string of the filename, or None.
+ The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
- return
+ return None
_, params = _parse_header(content_disposition[0])
@@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
return upload_name
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
- line (bytes): header to be parsed
+ line: header to be parsed
Returns:
- Tuple[bytes, dict[bytes, bytes]]:
- the main content-type, followed by the parameter dictionary
+ The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@@ -386,16 +400,16 @@ def _parse_header(line):
return key, pdict
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
- s (bytes): header to be parsed
+ s: header to be parsed
Returns:
- Iterable[bytes]: the split input
+ The split input
"""
while s[:1] == b";":
s = s[1:]
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..4e4c6971f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
# limitations under the License.
#
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
from synapse.http.server import DirectServeJsonResource, respond_with_json
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..3ed219ae43 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
-import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True
- )
+ allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
import functools
import os
import re
+from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""
- def __init__(self, primary_base_path):
+ def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
def default_thumbnail_rel(
- self, default_top_level, default_sub_type, width, height, content_type, method
- ):
+ self,
+ default_top_level: str,
+ default_sub_type: str,
+ width: int,
+ height: int,
+ content_type: str,
+ method: str,
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
- def local_media_filepath_rel(self, media_id):
+ def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
- def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+ def local_media_thumbnail_rel(
+ self, media_id: str, width: int, height: int, content_type: str, method: str
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)
- def remote_media_filepath_rel(self, server_name, file_id):
+ def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel(
- self, server_name, file_id, width, height, content_type, method
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ width: int,
+ height: int,
+ content_type: str,
+ method: str,
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
- self, server_name, file_id, width, height, content_type
+ self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)
- def remote_media_thumbnail_dir(self, server_name, file_id):
+ def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)
- def url_cache_filepath_rel(self, media_id):
+ def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
- def url_cache_filepath_dirs_to_delete(self, media_id):
+ def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
- def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+ def url_cache_thumbnail_rel(
+ self, media_id: str, width: int, height: int, content_type: str, method: str
+ ) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
- def url_cache_thumbnail_directory(self, media_id):
+ def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)
- def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 83beb02b05..4c9946a616 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import errno
import logging
import os
import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
@@ -73,16 +76,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.primary_base_path = hs.config.media_store_path
- self.filepaths = MediaFilePaths(self.primary_base_path)
+ self.primary_base_path = hs.config.media_store_path # type: str
+ self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
- self.recently_accessed_remotes = set()
- self.recently_accessed_locals = set()
+ self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
+ self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)
- async def _update_recently_accessed(self):
+ async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)
- def mark_recently_accessed(self, server_name, media_id):
+ def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.
Args:
- server_name (str|None): Origin server of media, or None if local
- media_id (str): The media ID of the content
+ server_name: Origin server of media, or None if local
+ media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+ def _generate_thumbnail(
+ self,
+ thumbnailer: Thumbnailer,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[BytesIO]:
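+ # Returns the thumbnail bytes, or None if the source image is too
+ # large or the thumbnailing method is unrecognised.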
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -470,22 +480,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
- t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+ return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
- else:
- t_byte_source = None
+ return thumbnailer.scale(t_width, t_height, t_type)
- return t_byte_source
+ return None
async def generate_local_exact_thumbnail(
self,
@@ -776,7 +784,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height}
- async def delete_old_remote_media(self, before_ts):
+ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..89cdd605aa 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
@@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
- open_file (file): A file like object to be streamed ot the client,
+ open_file: A file like object to be streamed to the client,
is closed when finished streaming.
"""
- def __init__(self, open_file):
+ def __init__(self, open_file: IO):
self.open_file = open_file
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 1082389d9b..a632099167 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import errno
import fnmatch
@@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
+if TYPE_CHECKING:
+ from lxml import etree
+
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo, media_storage):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
super().__init__()
self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url: str, user):
+ async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)
- async def _expire_url_cache_data(self):
+ async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
-def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+def decode_and_calc_og(
+ body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
# If there's no body, nothing useful is going to be found.
if not body:
return {}
@@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
return og
-def _calc_og(tree, media_uri):
+def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
@@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
- else:
+ elif og["og:description"]:
+ # This must be a non-empty string at this point.
+ assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
@@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
return og
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+ tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)
-def _rebase_url(url, base):
- base = list(urlparse.urlparse(base))
- url = list(urlparse.urlparse(url))
- if not url[0]: # fix up schema
- url[0] = base[0] or "http"
- if not url[1]: # fix up hostname
- url[1] = base[1]
- if not url[2].startswith("/"):
- url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
- return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
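+ """Resolve a possibly-relative URL against a base URL, returning an absolute URL."""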
+ base_parts = list(urlparse.urlparse(base))
+ url_parts = list(urlparse.urlparse(url))
+ if not url_parts[0]: # fix up scheme
+ url_parts[0] = base_parts[0] or "http"
+ if not url_parts[1]: # fix up hostname
+ url_parts[1] = base_parts[1]
+ if not url_parts[2].startswith("/"):
+ url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+ return urlparse.urlunparse(url_parts)
-def _is_media(content_type):
- if content_type.lower().startswith("image/"):
- return True
+def _is_media(content_type: str) -> bool:
+ return content_type.lower().startswith("image/")
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
- if content_type.startswith("text/html") or content_type.startswith(
+ return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
- ):
- return True
+ )
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+ text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 67f67efde7..e92006faa9 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
import os
import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
@@ -27,13 +28,17 @@ from .media_storage import FileResponder
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- async def store_file(self, path: str, file_info: FileInfo):
+ @abc.abstractmethod
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""
+ @abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
- def __str__(self):
+ def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
- async def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
@@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- return await maybe_awaitable(self.backend.store_file(path, file_info))
+ await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
@@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
- return None
- async def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# fetch is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
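
Both wrapper methods lean on the same guard: a provider written as a plain synchronous function is tolerated by awaiting its result only if it is awaitable. Synapse has its own helper for this; a standalone equivalent, written here as an assumption for illustration:

import inspect
from typing import Any

async def maybe_awaitable(value: Any) -> Any:
    # Await `value` if it is awaitable, otherwise return it unchanged.
    if inspect.isawaitable(value):
        return await value
    return value

# A wrapper can then call either style of backend uniformly:
#   result = await maybe_awaitable(backend.fetch(path, file_info))
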
@@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
- hs (HomeServer)
+ hs
config: The config returned by `parse_config`.
"""
- def __init__(self, hs, config):
+ def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- async def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return await defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- async def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
+ return None
+
@staticmethod
- def parse_config(config):
+ def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.
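
With `StorageProvider` now an abstract base class, a backend that fails to override both `store_file` and `fetch` can no longer be instantiated. A minimal sketch of a conforming provider, reusing names this module already imports (`FileInfo`, `Responder`, `Optional`); the class name is illustrative:

class NullStorageProvider(StorageProvider):
    """Illustrative provider that discards writes and never serves reads."""

    async def store_file(self, path: str, file_info: FileInfo) -> None:
        # Accept the upload but store nothing.
        return None

    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
        # Nothing is ever stored, so nothing can be fetched.
        return None

# Omitting either method would now fail fast at construction time with
# TypeError: Can't instantiate abstract class ... with abstract methods ...
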
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..d6880f2e6e 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo, media_storage):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
super().__init__()
self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
- self, request, media_id, width, height, method, m_type
- ):
+ self,
+ request: Request,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ ) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail(
self,
- request,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- ):
+ request: Request,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ ) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
- request,
- server_name,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- ):
+ request: Request,
+ server_name: str,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ ) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
- self, request, server_name, media_id, width, height, method, m_type
- ):
+ self,
+ request: Request,
+ server_name: str,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ ) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
def _select_thumbnail(
self,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
thumbnail_infos,
- ):
+ ) -> dict:
d_w = desired_width
d_h = desired_height
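
`_select_thumbnail` now advertises that it returns a dict describing the chosen pre-generated thumbnail. The selection heuristic itself lies outside this hunk; purely to illustrate the shape of the problem, a simplified stand-in that prefers the smallest thumbnail covering the requested size (the `thumbnail_width`/`thumbnail_height` keys are assumptions, and this is not Synapse's actual heuristic):

from typing import Iterable, Optional

def pick_thumbnail(
    desired_width: int, desired_height: int, thumbnail_infos: Iterable[dict]
) -> Optional[dict]:
    infos = list(thumbnail_infos)

    def area(t: dict) -> int:
        return t["thumbnail_width"] * t["thumbnail_height"]

    big_enough = [
        t
        for t in infos
        if t["thumbnail_width"] >= desired_width
        and t["thumbnail_height"] >= desired_height
    ]
    if big_enough:
        # Smallest thumbnail that still covers the requested size.
        return min(big_enough, key=area)
    # Otherwise fall back to the largest thumbnail available, if any.
    return max(infos, key=area, default=None)
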
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..07903e4017 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# limitations under the License.
import logging
from io import BytesIO
+from typing import Tuple
from PIL import Image
@@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
- def __init__(self, input_path):
+ def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
- def transpose(self):
+ def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag
Returns:
- Tuple[int, int]: (width, height) containing the new image size in pixels.
+ A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None
return self.image.size
- def aspect(self, max_width, max_height):
+ def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
@@ -91,7 +93,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
- def _resize(self, width, height):
+ def _resize(self, width: int, height: int) -> Image.Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
- def scale(self, width, height, output_type):
+ def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.
Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
- def crop(self, width, height, output_type):
+ def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
- def _encode_image(self, output_image, output_type):
+ def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 42febc9afc..6da76ae994 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,18 +15,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
@@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: Request) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
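
The same typing pattern recurs throughout this change: `HomeServer` and friends are imported only under `TYPE_CHECKING`, and annotations use string forward references, so mypy sees the types while no import cycle is created at runtime. In general form (the class name is illustrative):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while type checking; never executed at runtime.
    from synapse.app.homeserver import HomeServer

class SomeResource:
    def __init__(self, hs: "HomeServer"):
        # The quoted annotation is resolved lazily, so HomeServer need not
        # be importable when this module is loaded.
        self.hs = hs
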
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@
import logging
from typing import Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
@@ -28,6 +30,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+ """Return value for _calculate_chain_cover_txn.
+ """
+
+ # The last room_id/depth/stream processed.
+ room_id = attr.ib(type=str)
+ depth = attr.ib(type=int)
+ stream = attr.ib(type=int)
+
+ # Number of rows processed
+ processed_count = attr.ib(type=int)
+
+ # Map from room_id to the last depth/stream processed for each room for
+ # which we have processed all events (i.e. the rooms whose
+ # `has_auth_chain_index` flag we can now flip)
+ finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
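
Being frozen and slotted, `_CalculateChainCover` gives the transaction a named, immutable return value rather than an anonymous tuple. For instance (values illustrative):

cover = _CalculateChainCover(
    room_id="!room:example.com",
    depth=10,
    stream=250,
    processed_count=100,
    finished_room_map={},
)
cover.depth            # 10: fields are read by name, not by tuple index
# cover.depth = 11     # would raise attr.exceptions.FrozenInstanceError
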
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
current_room_id = progress.get("current_room_id", "")
- # Have we finished processing the current room.
- finished = progress.get("finished", True)
-
# Where we've processed up to in the room, defaults to the start of the
# room.
last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1)
- # Have we set the `has_auth_chain_index` for the room yet.
- has_set_room_has_chain_index = progress.get(
- "has_set_room_has_chain_index", False
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ current_room_id,
+ last_depth,
+ last_stream,
+ batch_size,
+ single_room=False,
)
- if finished:
- # If we've finished with the previous room (or its our first
- # iteration) we move on to the next room.
-
- def _get_next_room(txn: Cursor) -> Optional[str]:
- sql = """
- SELECT room_id FROM rooms
- WHERE room_id > ?
- AND (
- NOT has_auth_chain_index
- OR has_auth_chain_index IS NULL
- )
- ORDER BY room_id
- LIMIT 1
- """
- txn.execute(sql, (current_room_id,))
- row = txn.fetchone()
- if row:
- return row[0]
+ finished = result.processed_count == 0
- return None
-
- current_room_id = await self.db_pool.runInteraction(
- "_chain_cover_index", _get_next_room
- )
- if not current_room_id:
- await self.db_pool.updates._end_background_update("chain_cover")
- return 0
-
- logger.debug("Adding chain cover to %s", current_room_id)
-
- def _calculate_auth_chain(
- txn: Cursor, last_depth: int, last_stream: int
- ) -> Tuple[int, int, int]:
- # Get the next set of events in the room (that we haven't already
- # computed chain cover for). We do this in topological order.
-
- # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
- # comparison, but that is not supported on older SQLite versions
- tuple_clause, tuple_args = make_tuple_comparison_clause(
- self.database_engine,
- [
- ("topological_ordering", last_depth),
- ("stream_ordering", last_stream),
- ],
- )
+ total_rows_processed = result.processed_count
+ current_room_id = result.room_id
+ last_depth = result.depth
+ last_stream = result.stream
- sql = """
- SELECT
- event_id, state_events.type, state_events.state_key,
- topological_ordering, stream_ordering
- FROM events
- INNER JOIN state_events USING (event_id)
- LEFT JOIN event_auth_chains USING (event_id)
- LEFT JOIN event_auth_chain_to_calculate USING (event_id)
- WHERE events.room_id = ?
- AND event_auth_chains.event_id IS NULL
- AND event_auth_chain_to_calculate.event_id IS NULL
- AND %(tuple_cmp)s
- ORDER BY topological_ordering, stream_ordering
- LIMIT ?
- """ % {
- "tuple_cmp": tuple_clause,
- }
-
- args = [current_room_id]
- args.extend(tuple_args)
- args.append(batch_size)
-
- txn.execute(sql, args)
- rows = txn.fetchall()
-
- # Put the results in the necessary format for
- # `_add_chain_cover_index`
- event_to_room_id = {row[0]: current_room_id for row in rows}
- event_to_types = {row[0]: (row[1], row[2]) for row in rows}
-
- new_last_depth = rows[-1][3] if rows else last_depth # type: int
- new_last_stream = rows[-1][4] if rows else last_stream # type: int
-
- count = len(rows)
-
- # We also need to fetch the auth events for them.
- auth_events = self.db_pool.simple_select_many_txn(
- txn,
- table="event_auth",
- column="event_id",
- iterable=event_to_room_id,
- keyvalues={},
- retcols=("event_id", "auth_id"),
- )
-
- event_to_auth_chain = {} # type: Dict[str, List[str]]
- for row in auth_events:
- event_to_auth_chain.setdefault(row["event_id"], []).append(
- row["auth_id"]
- )
-
- # Calculate and persist the chain cover index for this set of events.
- #
- # Annoyingly we need to gut wrench into the persit event store so that
- # we can reuse the function to calculate the chain cover for rooms.
- PersistEventsStore._add_chain_cover_index(
- txn,
- self.db_pool,
- event_to_room_id,
- event_to_types,
- event_to_auth_chain,
- )
-
- return new_last_depth, new_last_stream, count
-
- last_depth, last_stream, count = await self.db_pool.runInteraction(
- "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
- )
-
- total_rows_processed = count
-
- if count < batch_size and not has_set_room_has_chain_index:
+ for room_id, (depth, stream) in result.finished_room_map.items():
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that its possible for
# further events to be persisted between the above and setting the
@@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.simple_update(
table="rooms",
- keyvalues={"room_id": current_room_id},
+ keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
- has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
- last_depth, last_stream, count = await self.db_pool.runInteraction(
- "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ room_id,
+ depth,
+ stream,
+ batch_size=None,
+ single_room=True,
)
- total_rows_processed += count
+ total_rows_processed += result.processed_count
- # Note that at this point its technically possible that more events
- # than our `batch_size` have been persisted without their chain
- # cover, so we need to continue processing this room if the last
- # count returned was equal to the `batch_size`.
+ if finished:
+ await self.db_pool.updates._end_background_update("chain_cover")
+ return total_rows_processed
- if count < batch_size:
- # We've finished calculating the index for this room, move on to the
- # next room.
- await self.db_pool.updates._background_update_progress(
- "chain_cover", {"current_room_id": current_room_id, "finished": True},
- )
- else:
- # We still have outstanding events to calculate the index for.
- await self.db_pool.updates._background_update_progress(
- "chain_cover",
- {
- "current_room_id": current_room_id,
- "last_depth": last_depth,
- "last_stream": last_stream,
- "has_auth_chain_index": has_set_room_has_chain_index,
- "finished": False,
- },
- )
+ await self.db_pool.updates._background_update_progress(
+ "chain_cover",
+ {
+ "current_room_id": current_room_id,
+ "last_depth": last_depth,
+ "last_stream": last_stream,
+ },
+ )
return total_rows_processed
+
+ def _calculate_chain_cover_txn(
+ self,
+ txn: Cursor,
+ last_room_id: str,
+ last_depth: int,
+ last_stream: int,
+ batch_size: Optional[int],
+ single_room: bool,
+ ) -> _CalculateChainCover:
+ """Calculate the chain cover for `batch_size` events, ordered by
+ `(room_id, depth, stream)`.
+
+ Args:
+ txn: The database transaction.
+ last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+ tuple to fetch results after.
+ batch_size: The maximum number of events to process. If None then
+ no limit.
+ single_room: Whether to calculate the index for just the given
+ room.
+ """
+
+ # Get the next set of events in the room (that we haven't already
+ # computed chain cover for). We do this in topological order.
+
+ # We want to do a `(room_id, topological_ordering, stream_ordering) > (?,?,?)`
+ # comparison, but that is not supported on older SQLite versions
+ tuple_clause, tuple_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [
+ ("events.room_id", last_room_id),
+ ("topological_ordering", last_depth),
+ ("stream_ordering", last_stream),
+ ],
+ )
+
+ extra_clause = ""
+ if single_room:
+ extra_clause = "AND events.room_id = ?"
+ tuple_args.append(last_room_id)
+
+ sql = """
+ SELECT
+ event_id, state_events.type, state_events.state_key,
+ topological_ordering, stream_ordering,
+ events.room_id
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+ WHERE event_auth_chains.event_id IS NULL
+ AND event_auth_chain_to_calculate.event_id IS NULL
+ AND %(tuple_cmp)s
+ %(extra)s
+ ORDER BY events.room_id, topological_ordering, stream_ordering
+ %(limit)s
+ """ % {
+ "tuple_cmp": tuple_clause,
+ "limit": "LIMIT ?" if batch_size is not None else "",
+ "extra": extra_clause,
+ }
+
+ if batch_size is not None:
+ tuple_args.append(batch_size)
+
+ txn.execute(sql, tuple_args)
+ rows = txn.fetchall()
+
+ # Put the results in the necessary format for
+ # `_add_chain_cover_index`
+ event_to_room_id = {row[0]: row[5] for row in rows}
+ event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+ # Calculate the new last position we've processed up to.
+ new_last_depth = rows[-1][3] if rows else last_depth # type: int
+ new_last_stream = rows[-1][4] if rows else last_stream # type: int
+ new_last_room_id = rows[-1][5] if rows else "" # type: str
+
+ # Map from room_id to last depth/stream_ordering processed for the room,
+ # excluding the last room (which we're likely still processing). We also
+ # need to include the room passed in if it's not included in the result
+ # set (as we then know we've processed all events in said room).
+ #
+ # This is the set of rooms that we can now safely flip the
+ # `has_auth_chain_index` bit for.
+ finished_rooms = {
+ row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+ }
+ if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+ finished_rooms[last_room_id] = (last_depth, last_stream)
+
+ count = len(rows)
+
+ # We also need to fetch the auth events for them.
+ auth_events = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ )
+
+ event_to_auth_chain = {} # type: Dict[str, List[str]]
+ for row in auth_events:
+ event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+ # Calculate and persist the chain cover index for this set of events.
+ #
+ # Annoyingly we need to gut wrench into the persist event store so that
+ # we can reuse the function to calculate the chain cover for rooms.
+ PersistEventsStore._add_chain_cover_index(
+ txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ )
+
+ return _CalculateChainCover(
+ room_id=new_last_room_id,
+ depth=new_last_depth,
+ stream=new_last_stream,
+ processed_count=count,
+ finished_room_map=finished_rooms,
+ )
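
The rewritten background update walks every room in a single `(room_id, topological_ordering, stream_ordering)` keyset order, resuming from a stored progress token instead of iterating room-by-room. The resumable-pagination shape, reduced to a standalone sketch over an in-memory table (SQLite is used purely for illustration, and the expanded OR-chain below is the kind of clause a helper like `make_tuple_comparison_clause` generates for `(a, b, c) > (?, ?, ?)` on engines without row-value comparison):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, depth INT, stream INT)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?, ?)",
    [("!a", 1, 1), ("!a", 1, 2), ("!a", 2, 3), ("!b", 1, 4), ("!b", 2, 5)],
)

def next_batch(last_room: str, last_depth: int, last_stream: int, batch_size: int):
    # Fetch the next batch strictly after the (room, depth, stream) token.
    sql = """
        SELECT room_id, depth, stream FROM events
        WHERE room_id > ?
           OR (room_id = ? AND depth > ?)
           OR (room_id = ? AND depth = ? AND stream > ?)
        ORDER BY room_id, depth, stream
        LIMIT ?
    """
    args = (
        last_room,
        last_room, last_depth,
        last_room, last_depth, last_stream,
        batch_size,
    )
    return conn.execute(sql, args).fetchall()

# Resume from the stored progress token; an empty batch means we're finished.
print(next_batch("", -1, -1, 3))   # [('!a', 1, 1), ('!a', 1, 2), ('!a', 2, 3)]
print(next_batch("!a", 2, 3, 3))   # [('!b', 1, 4), ('!b', 2, 5)]
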
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4b2f224718..283c8a5e22 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_local_media_before(
self, before_ts: int, size_gt: int, keep_profiles: bool,
- ) -> Optional[List[str]]:
+ ) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 77ba9d819e..bc7621b8d6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -17,14 +17,13 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
-from canonicaljson import encode_canonical_json
-
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
- "data": bytearray(encode_canonical_json(data)),
+ "data": json_encoder.encode(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..b103c8694c 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
if len(items) <= maxitems:
return str(items)
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False
+
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+ are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
+ 'val' is anything else.
+
+ This is lifted from distutils.util.strtobool, with the exception that it actually
+ returns a bool, rather than an int.
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ elif val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ else:
+ raise ValueError("invalid truth value %r" % (val,))
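
Usage matches `distutils.util.strtobool`, except that the result is a genuine bool:

assert strtobool("Yes") is True    # matching is case-insensitive
assert strtobool("off") is False
try:
    strtobool("maybe")
except ValueError:
    pass  # anything outside the documented values raises
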