diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index a3b95f4de0..08fe160c98 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -81,7 +81,7 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of
# users.
return
- elif requester.app_service and not self._track_appservice_user_ips:
+ if requester.app_service and not self._track_appservice_user_ips:
# If we're authenticated as an appservice then we only block
# auth if `track_appservice_user_ips` is set, as that option
# implicitly means that application services are part of MAU
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index d3270cd6d2..032c69b210 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -39,12 +39,12 @@ class ConsentURIBuilder:
Args:
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
- if hs_config.form_secret is None:
+ if hs_config.key.form_secret is None:
raise ConfigError("form_secret not set in config")
if hs_config.server.public_baseurl is None:
raise ConfigError("public_baseurl not set in config")
- self._hmac_secret = hs_config.form_secret.encode("utf-8")
+ self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
self._public_baseurl = hs_config.server.public_baseurl
def build_user_consent_uri(self, user_id):
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index d1aa2e7fb5..548f6dcde9 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -88,8 +88,8 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
appname,
soft_file_limit=config.soft_file_limit,
gc_thresholds=config.gc_thresholds,
- pid_file=config.worker_pid_file,
- daemonize=config.worker_daemonize,
+ pid_file=config.worker.worker_pid_file,
+ daemonize=config.worker.worker_daemonize,
print_pidfile=config.print_pidfile,
logger=logger,
run_command=run_command,
@@ -424,12 +424,14 @@ def setup_sentry(hs):
hs (synapse.server.HomeServer)
"""
- if not hs.config.sentry_enabled:
+ if not hs.config.metrics.sentry_enabled:
return
import sentry_sdk
- sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))
+ sentry_sdk.init(
+ dsn=hs.config.metrics.sentry_dsn, release=get_version_string(synapse)
+ )
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 5e956b1e27..f2c5b75247 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -186,13 +186,13 @@ def start(config_options):
config.worker.worker_app = "synapse.app.admin_cmd"
if (
- not config.worker_daemonize
- and not config.worker_log_file
- and not config.worker_log_config
+ not config.worker.worker_daemonize
+ and not config.worker.worker_log_file
+ and not config.worker.worker_log_config
):
# Since we're meant to be run as a "command" let's not redirect stdio
# unless we've actually set log config.
- config.no_redirect_stdio = True
+ config.logging.no_redirect_stdio = True
# Explicitly disable background processes
config.update_user_directory = False
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 33afd59c72..3036e1b4a0 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -140,7 +140,7 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.http_client = hs.get_simple_http_client()
- self.main_uri = hs.config.worker_main_http_uri
+ self.main_uri = hs.config.worker.worker_main_http_uri
async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -321,7 +321,7 @@ class GenericWorkerServer(HomeServer):
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
- if self.config.can_load_media_repo:
+ if self.config.media.can_load_media_repo:
media_repo = self.get_media_repository_resource()
# We need to serve the admin servlets for media on the
@@ -384,7 +384,7 @@ class GenericWorkerServer(HomeServer):
logger.info("Synapse worker now listening on port %d", port)
def start_listening(self):
- for listener in self.config.worker_listeners:
+ for listener in self.config.worker.worker_listeners:
if listener.type == "http":
self._listen_http(listener)
elif listener.type == "manhole":
@@ -395,7 +395,7 @@ class GenericWorkerServer(HomeServer):
manhole_globals={"hs": self},
)
elif listener.type == "metrics":
- if not self.config.enable_metrics:
+ if not self.config.metrics.enable_metrics:
logger.warning(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -488,7 +488,7 @@ def start(config_options):
register_start(_base.start, hs)
# redirect stdio to the logs, if configured.
- if not hs.config.no_redirect_stdio:
+ if not hs.config.logging.no_redirect_stdio:
redirect_stdio_to_logs()
_base.start_worker_reactor("synapse-generic-worker", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b909f8db8d..205831dcda 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -195,7 +195,7 @@ class SynapseHomeServer(HomeServer):
}
)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
from synapse.rest.synapse.client.password_reset import (
PasswordResetSubmitTokenResource,
)
@@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer):
)
if name in ["media", "federation", "client"]:
- if self.config.enable_media_repo:
+ if self.config.media.enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update(
{MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
@@ -269,7 +269,7 @@ class SynapseHomeServer(HomeServer):
# https://twistedmatrix.com/trac/ticket/7678
resources[WEB_CLIENT_PREFIX] = File(webclient_loc)
- if name == "metrics" and self.config.enable_metrics:
+ if name == "metrics" and self.config.metrics.enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
if name == "replication":
@@ -278,7 +278,7 @@ class SynapseHomeServer(HomeServer):
return resources
def start_listening(self):
- if self.config.redis_enabled:
+ if self.config.redis.redis_enabled:
# If redis is enabled we connect via the replication command handler
# in the same way as the workers (since we're effectively a client
# rather than a server).
@@ -305,7 +305,7 @@ class SynapseHomeServer(HomeServer):
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener.type == "metrics":
- if not self.config.enable_metrics:
+ if not self.config.metrics.enable_metrics:
logger.warning(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -366,7 +366,7 @@ def setup(config_options):
async def start():
# Load the OIDC provider metadatas, if OIDC is enabled.
- if hs.config.oidc_enabled:
+ if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
@@ -455,7 +455,7 @@ def main():
hs = setup(sys.argv[1:])
# redirect stdio to the logs, if configured.
- if not hs.config.no_redirect_stdio:
+ if not hs.config.logging.no_redirect_stdio:
redirect_stdio_to_logs()
run(hs)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 4a95da90f9..49e7a45e5c 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -131,10 +131,12 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
log_level = synapse_logger.getEffectiveLevel()
stats["log_level"] = logging.getLevelName(log_level)
- logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ logger.info(
+ "Reporting stats to %s: %s" % (hs.config.metrics.report_stats_endpoint, stats)
+ )
try:
await hs.get_proxied_http_client().put_json(
- hs.config.report_stats_endpoint, stats
+ hs.config.metrics.report_stats_endpoint, stats
)
except Exception as e:
logger.warning("Error reporting stats: %s", e)
@@ -188,7 +190,7 @@ def start_phone_stats_home(hs):
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
- if hs.config.report_stats:
+ if hs.config.metrics.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 2cc242782a..d974a1a2a8 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -200,11 +200,7 @@ class Config:
@classmethod
def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path)
- try:
- os.makedirs(dir_path)
- except OSError as e:
- if e.errno != errno.EEXIST:
- raise
+ os.makedirs(dir_path, exist_ok=True)
if not os.path.isdir(dir_path):
raise ConfigError("%s is not a directory" % (dir_path,))
return dir_path
@@ -693,8 +689,7 @@ class RootConfig:
open_private_ports=config_args.open_private_ports,
)
- if not path_exists(config_dir_path):
- os.makedirs(config_dir_path)
+ os.makedirs(config_dir_path, exist_ok=True)
with open(config_path, "w") as config_file:
config_file.write(config_str)
config_file.write("\n\n# vim:ft=yaml")
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index b05a9bd97f..ecc43b08b9 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -13,6 +13,7 @@
# limitations under the License.
from os import path
+from typing import Optional
from synapse.config import ConfigError
@@ -78,8 +79,8 @@ class ConsentConfig(Config):
def __init__(self, *args):
super().__init__(*args)
- self.user_consent_version = None
- self.user_consent_template_dir = None
+ self.user_consent_version: Optional[str] = None
+ self.user_consent_template_dir: Optional[str] = None
self.user_consent_server_notice_content = None
self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error = None
@@ -94,7 +95,9 @@ class ConsentConfig(Config):
return
self.user_consent_version = str(consent_config["version"])
self.user_consent_template_dir = self.abspath(consent_config["template_dir"])
- if not path.isdir(self.user_consent_template_dir):
+ if not isinstance(self.user_consent_template_dir, str) or not path.isdir(
+ self.user_consent_template_dir
+ ):
raise ConfigError(
"Could not find template directory '%s'"
% (self.user_consent_template_dir,)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index aca9d467e6..0a08231e5a 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -322,7 +322,9 @@ def setup_logging(
"""
log_config_path = (
- config.worker_log_config if use_worker_options else config.log_config
+ config.worker.worker_log_config
+ if use_worker_options
+ else config.logging.log_config
)
# Perform one-time logging configuration.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7b9109a592..ad8715da29 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1447,7 +1447,7 @@ def read_gc_thresholds(thresholds):
return None
try:
assert len(thresholds) == 3
- return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
+ return int(thresholds[0]), int(thresholds[1]), int(thresholds[2])
except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index d310976fe3..2a6110eb10 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -74,8 +74,8 @@ class ServerContextFactory(ContextFactory):
context.set_options(
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
)
- context.use_certificate_chain_file(config.tls_certificate_file)
- context.use_privatekey(config.tls_private_key)
+ context.use_certificate_chain_file(config.tls.tls_certificate_file)
+ context.use_privatekey(config.tls.tls_private_key)
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list(
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f8d898c3b1..5ba01eeef9 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -80,9 +80,7 @@ class EventContext:
(type, state_key) -> event_id
- FIXME: what is this for an outlier? it seems ill-defined. It seems like
- it could be either {}, or the state we were given by the remote
- server, depending on $THINGS
+ For an outlier, this is {}
Note that this is a private attribute: it should be accessed via
``get_current_state_ids``. _AsyncEventContext impl calculates this
@@ -96,7 +94,7 @@ class EventContext:
(type, state_key) -> event_id
- FIXME: again, what is this for an outlier?
+ For an outlier, this is {}
As with _current_state_ids, this is a private attribute. It should be
accessed via get_prev_state_ids.
@@ -130,6 +128,14 @@ class EventContext:
delta_ids=delta_ids,
)
+ @staticmethod
+ def for_outlier():
+ """Return an EventContext instance suitable for persisting an outlier event"""
+ return EventContext(
+ current_state_ids={},
+ prev_state_ids={},
+ )
+
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 57f1d53fa8..c389f70b8d 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,6 +46,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
]
USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[
+ [str, List[str], List[Dict[str, str]]], Awaitable[bool]
+]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
@@ -78,7 +81,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
"""
spam_checkers: List[Any] = []
api = hs.get_module_api()
- for module, config in hs.config.spam_checkers:
+ for module, config in hs.config.spamchecker.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
@@ -164,6 +167,9 @@ class SpamChecker:
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
+ self._user_may_create_room_with_invites_callbacks: List[
+ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+ ] = []
self._user_may_create_room_alias_callbacks: List[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = []
@@ -183,6 +189,9 @@ class SpamChecker:
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+ user_may_create_room_with_invites: Optional[
+ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+ ] = None,
user_may_create_room_alias: Optional[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = None,
@@ -203,6 +212,11 @@ class SpamChecker:
if user_may_create_room is not None:
self._user_may_create_room_callbacks.append(user_may_create_room)
+ if user_may_create_room_with_invites is not None:
+ self._user_may_create_room_with_invites_callbacks.append(
+ user_may_create_room_with_invites,
+ )
+
if user_may_create_room_alias is not None:
self._user_may_create_room_alias_callbacks.append(
user_may_create_room_alias,
@@ -283,6 +297,34 @@ class SpamChecker:
return True
+ async def user_may_create_room_with_invites(
+ self,
+ userid: str,
+ invites: List[str],
+ threepid_invites: List[Dict[str, str]],
+ ) -> bool:
+ """Checks if a given user may create a room with invites
+
+ If this method returns false, the creation request will be rejected.
+
+ Args:
+ userid: The ID of the user attempting to create a room
+ invites: The IDs of the Matrix users to be invited if the room creation is
+ allowed.
+ threepid_invites: The threepids to be invited if the room creation is allowed,
+ as a dict including a "medium" key indicating the threepid's medium (e.g.
+ "email") and an "address" key indicating the threepid's address (e.g.
+ "alice@example.com")
+
+ Returns:
+ True if the user may create the room, otherwise False
+ """
+ for callback in self._user_may_create_room_with_invites_callbacks:
+ if await callback(userid, invites, threepid_invites) is False:
+ return False
+
+ return True
+
async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> bool:
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 7a6eb3e516..d94b1bb4d2 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -42,10 +42,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
- if hs.config.third_party_event_rules is None:
+ if hs.config.thirdpartyrules.third_party_event_rules is None:
return
- module, config = hs.config.third_party_event_rules
+ module, config = hs.config.thirdpartyrules.third_party_event_rules
api = hs.get_module_api()
third_party_rules = module(config=config, module_api=api)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 1416abd0fb..584836c04a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -501,8 +501,6 @@ class FederationClient(FederationBase):
destination, auth_chain, outlier=True, room_version=room_version
)
- signed_auth.sort(key=lambda e: e.depth)
-
return signed_auth
def _is_unknown_endpoint(
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index c11d1f6d31..afe35e72b6 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -560,7 +560,7 @@ class PerDestinationQueue:
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
- return (edus, now_stream_id)
+ return edus, now_stream_id
async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
last_device_stream_id = self._last_device_stream_id
@@ -593,7 +593,7 @@ class PerDestinationQueue:
stream_id,
)
- return (edus, stream_id)
+ return edus, stream_id
def _start_catching_up(self) -> None:
"""
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 624c859f1e..cef65929c5 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -49,7 +49,9 @@ class Authenticator:
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
- self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.federation_domain_whitelist = (
+ hs.config.federation.federation_domain_whitelist
+ )
self.notifier = hs.get_notifier()
self.replication_client = None
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index d6b75ac27f..449bbc7004 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -847,16 +847,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
UserID.from_string(requester_user_id)
)
if not is_admin:
- if not self.hs.config.enable_group_creation:
+ if not self.hs.config.groups.enable_group_creation:
raise SynapseError(
403, "Only a server admin can create groups on this server"
)
localpart = group_id_obj.localpart
- if not localpart.startswith(self.hs.config.group_creation_prefix):
+ if not localpart.startswith(self.hs.config.groups.group_creation_prefix):
raise SynapseError(
400,
"Can only create groups with prefix %r on this server"
- % (self.hs.config.group_creation_prefix,),
+ % (self.hs.config.groups.group_creation_prefix,),
)
profile = content.get("profile", {})
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 4724565ba5..5a5f124ddf 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -47,7 +47,7 @@ class AccountValidityHandler:
self.send_email_handler = self.hs.get_send_email_handler()
self.clock = self.hs.get_clock()
- self._app_name = self.hs.config.email_app_name
+ self._app_name = self.hs.config.email.email_app_name
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index b7213b67a5..163278708c 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -52,7 +52,7 @@ class ApplicationServicesHandler:
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
self.clock = hs.get_clock()
- self.notify_appservices = hs.config.notify_appservices
+ self.notify_appservices = hs.config.appservice.notify_appservices
self.event_sources = hs.get_event_sources()
self.current_max = 0
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bcd4249e09..a8c717efd5 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -210,15 +210,15 @@ class AuthHandler(BaseHandler):
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
- for module, config in hs.config.password_providers
+ for module, config in hs.config.authproviders.password_providers
]
logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
- self._password_enabled = hs.config.password_enabled
- self._password_localdb_enabled = hs.config.password_localdb_enabled
+ self._password_enabled = hs.config.auth.password_enabled
+ self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = set()
@@ -250,7 +250,7 @@ class AuthHandler(BaseHandler):
)
# The number of seconds to keep a UI auth session active.
- self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+ self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
@@ -277,23 +277,25 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
- self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+ self._sso_redirect_confirm_template = (
+ hs.config.sso.sso_redirect_confirm_template
+ )
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
- self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+ self._sso_auth_confirm_template = hs.config.sso.sso_auth_confirm_template
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
- hs.config.sso_account_deactivated_template
+ hs.config.sso.sso_account_deactivated_template
)
self._server_name = hs.config.server.server_name
# cast to tuple for use with str.startswith
- self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+ self._whitelisted_sso_clients = tuple(hs.config.sso.sso_client_whitelist)
# A mapping of user ID to extra attributes to include in the login
# response.
@@ -739,19 +741,19 @@ class AuthHandler(BaseHandler):
return canonical_id
def _get_params_recaptcha(self) -> dict:
- return {"public_key": self.hs.config.recaptcha_public_key}
+ return {"public_key": self.hs.config.captcha.recaptcha_public_key}
def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
- "version": self.hs.config.user_consent_version,
+ "version": self.hs.config.consent.user_consent_version,
"en": {
- "name": self.hs.config.user_consent_policy_name,
+ "name": self.hs.config.consent.user_consent_policy_name,
"url": "%s_matrix/consent?v=%s"
% (
self.hs.config.server.public_baseurl,
- self.hs.config.user_consent_version,
+ self.hs.config.consent.user_consent_version,
),
},
}
@@ -1016,7 +1018,7 @@ class AuthHandler(BaseHandler):
def can_change_password(self) -> bool:
"""Get whether users on this server are allowed to change or set a password.
- Both `config.password_enabled` and `config.password_localdb_enabled` must be true.
+ Both `config.auth.password_enabled` and `config.auth.password_localdb_enabled` must be true.
Note that any account (even SSO accounts) are allowed to add passwords if the above
is true.
@@ -1486,7 +1488,7 @@ class AuthHandler(BaseHandler):
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
- pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii")
@@ -1510,7 +1512,7 @@ class AuthHandler(BaseHandler):
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
- pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"),
checked_hash,
)
@@ -1802,7 +1804,7 @@ class MacaroonGenerator:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key,
+ key=self.hs.config.key.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index b0b188dc78..5d8f6c50a9 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -65,10 +65,10 @@ class CasHandler:
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
- self._cas_displayname_attribute = hs.config.cas_displayname_attribute
- self._cas_required_attributes = hs.config.cas_required_attributes
+ self._cas_server_url = hs.config.cas.cas_server_url
+ self._cas_service_url = hs.config.cas.cas_service_url
+ self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
+ self._cas_required_attributes = hs.config.cas.cas_required_attributes
self._http_client = hs.get_proxied_http_client()
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index a03ff9842b..9ae5b7750e 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -255,13 +255,16 @@ class DeactivateAccountHandler(BaseHandler):
Args:
user_id: ID of user to be re-activated
"""
- # Add the user to the directory, if necessary.
user = UserID.from_string(user_id)
- profile = await self.store.get_profileinfo(user.localpart)
- await self.user_directory_handler.handle_local_profile_change(user_id, profile)
# Ensure the user is not marked as erased.
await self.store.mark_user_not_erased(user_id)
# Mark the user as active.
await self.store.set_user_deactivated_status(user_id, False)
+
+ # Add the user to the directory, if necessary. Note that
+ # this must be done after the user is re-activated, because
+ # deactivated users are excluded from the user directory.
+ profile = await self.store.get_profileinfo(user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(user_id, profile)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index d487fee627..5cfba3c817 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -48,7 +48,7 @@ class DirectoryHandler(BaseHandler):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.config = hs.config
- self.enable_room_list_search = hs.config.enable_room_list_search
+ self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.require_membership_for_aliases
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -143,7 +143,7 @@ class DirectoryHandler(BaseHandler):
):
raise AuthError(403, "This user is not permitted to create this alias")
- if not self.config.is_alias_creation_allowed(
+ if not self.config.roomdirectory.is_alias_creation_allowed(
user_id, room_id, room_alias_str
):
# Lets just return a generic message, as there may be all sorts of
@@ -459,7 +459,7 @@ class DirectoryHandler(BaseHandler):
if canonical_alias:
room_aliases.append(canonical_alias)
- if not self.config.is_publishing_room_allowed(
+ if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
):
# Lets just return a generic message, as there may be all sorts of
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8e2cf3387a..b17ef2a9a1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -91,7 +91,7 @@ class FederationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
self._replication = hs.get_replication_data_handler()
@@ -593,6 +593,13 @@ class FederationHandler(BaseHandler):
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
+ # Mark the knock as an outlier as we don't yet have the state at this point in
+ # the DAG.
+ event.internal_metadata.outlier = True
+
+ # ... but tell /sync to send it to clients anyway.
+ event.internal_metadata.out_of_band_membership = True
+
# Record the room ID and its version so that we have a record of the room
await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=event_format_version
@@ -617,7 +624,7 @@ class FederationHandler(BaseHandler):
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
- context = await self.state_handler.compute_event_context(event)
+ context = EventContext.for_outlier()
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -807,7 +814,7 @@ class FederationHandler(BaseHandler):
)
)
- context = await self.state_handler.compute_event_context(event)
+ context = EventContext.for_outlier()
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -836,7 +843,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
- context = await self.state_handler.compute_event_context(event)
+ context = EventContext.for_outlier()
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1108,8 +1115,7 @@ class FederationHandler(BaseHandler):
events_to_context = {}
for e in itertools.chain(auth_events, state):
e.internal_metadata.outlier = True
- ctx = await self.state_handler.compute_event_context(e)
- events_to_context[e.event_id] = ctx
+ events_to_context[e.event_id] = EventContext.for_outlier()
event_map = {
e.event_id: e for e in itertools.chain(auth_events, state, [event])
@@ -1363,7 +1369,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
EventValidator().validate_new(event, self.config)
- return (event, context)
+ return event, context
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 3b95beeb08..01fd841122 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -27,11 +27,8 @@ from typing import (
Tuple,
)
-import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
@@ -54,11 +51,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
-from synapse.logging.context import (
- make_deferred_yieldable,
- nested_logging_context,
- run_in_background,
-)
+from synapse.logging.context import nested_logging_context, run_in_background
from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
@@ -75,7 +68,11 @@ from synapse.types import (
UserID,
get_domain_from_id,
)
-from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.async_helpers import (
+ Linearizer,
+ concurrently_execute,
+ yieldable_gather_results,
+)
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -92,30 +89,6 @@ soft_failed_event_counter = Counter(
)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _NewEventInfo:
- """Holds information about a received event, ready for passing to _auth_and_persist_events
-
- Attributes:
- event: the received event
-
- claimed_auth_event_map: a map of (type, state_key) => event for the event's
- claimed auth_events.
-
- This can include events which have not yet been persisted, in the case that
- we are backfilling a batch of events.
-
- Note: May be incomplete: if we were unable to find all of the claimed auth
- events. Also, treat the contents with caution: the events might also have
- been rejected, might not yet have been authorized themselves, or they might
- be in the wrong room.
-
- """
-
- event: EventBase
- claimed_auth_event_map: StateMap[EventBase]
-
-
class FederationEventHandler:
"""Handles events that originated from federation.
@@ -1107,7 +1080,7 @@ class FederationEventHandler:
room_version = await self._store.get_room_version(room_id)
- event_map: Dict[str, EventBase] = {}
+ events: List[EventBase] = []
async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
@@ -1125,8 +1098,7 @@ class FederationEventHandler:
event_id,
)
return
-
- event_map[event.event_id] = event
+ events.append(event)
except Exception as e:
logger.warning(
@@ -1137,11 +1109,29 @@ class FederationEventHandler:
)
await concurrently_execute(get_event, event_ids, 5)
- logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids))
+ logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
+ await self._auth_and_persist_fetched_events(destination, room_id, events)
+
+ async def _auth_and_persist_fetched_events(
+ self, origin: str, room_id: str, events: Iterable[EventBase]
+ ) -> None:
+ """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
+
+ The events to be persisted must be outliers.
+
+ We first sort the events to make sure that we process each event's auth_events
+ before the event itself, and then auth and persist them.
+
+ Notifies about the events where appropriate.
+
+ Params:
+ origin: where the events came from
+ room_id: the room that the events are meant to be in (though this has
+ not yet been checked)
+ events: the events that have been fetched
+ """
+ event_map = {event.event_id: event for event in events}
- # we now need to auth the events in an order which ensures that each event's
- # auth_events are authed before the event itself.
- #
# XXX: it might be possible to kick this process off in parallel with fetching
# the events.
while event_map:
@@ -1168,22 +1158,18 @@ class FederationEventHandler:
"Persisting %i of %i remaining events", len(roots), len(event_map)
)
- await self._auth_and_persist_fetched_events(destination, room_id, roots)
+ await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
for ev in roots:
del event_map[ev.event_id]
- async def _auth_and_persist_fetched_events(
+ async def _auth_and_persist_fetched_events_inner(
self, origin: str, room_id: str, fetched_events: Collection[EventBase]
) -> None:
- """Persist the events fetched by _get_events_and_persist.
-
- The events should not depend on one another, e.g. this should be used to persist
- a bunch of outliers, but not a chunk of individual events that depend
- on each other for state calculations.
+ """Helper for _auth_and_persist_fetched_events
- We also assume that all of the auth events for all of the events have already
- been persisted.
+ Persists a batch of events where we have (theoretically) already persisted all
+ of their auth events.
Notifies about the events where appropriate.
@@ -1191,7 +1177,7 @@ class FederationEventHandler:
origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
- event_id: map from event_id -> event for the fetched events
+ fetched_events: the events to persist
"""
# get all the auth events for all the events in this batch. By now, they should
# have been persisted.
@@ -1203,47 +1189,36 @@ class FederationEventHandler:
allow_rejected=True,
)
- event_infos = []
- for event in fetched_events:
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id)
- if ae:
+ async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
+ with nested_logging_context(suffix=event.event_id):
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id)
+ if not ae:
+ logger.warning(
+ "Event %s relies on auth_event %s, which could not be found.",
+ event,
+ auth_event_id,
+ )
+ # the fact we can't find the auth event doesn't mean it doesn't
+ # exist, which means it is premature to reject `event`. Instead we
+ # just ignore it for now.
+ return None
auth[(ae.type, ae.state_key)] = ae
- else:
- logger.info("Missing auth event %s", auth_event_id)
-
- event_infos.append(_NewEventInfo(event, auth))
-
- if not event_infos:
- return
- async def prep(ev_info: _NewEventInfo) -> EventContext:
- event = ev_info.event
- with nested_logging_context(suffix=event.event_id):
- res = await self._state_handler.compute_event_context(event)
- res = await self._check_event_auth(
+ context = EventContext.for_outlier()
+ context = await self._check_event_auth(
origin,
event,
- res,
- claimed_auth_event_map=ev_info.claimed_auth_event_map,
+ context,
+ claimed_auth_event_map=auth,
)
- return res
+ return event, context
- contexts = await make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(prep, ev_info) for ev_info in event_infos],
- consumeErrors=True,
- )
- )
-
- await self.persist_events_and_notify(
- room_id,
- [
- (ev_info.event, context)
- for ev_info, context in zip(event_infos, contexts)
- ],
+ events_to_persist = (
+ x for x in await yieldable_gather_results(prep, fetched_events) if x
)
+ await self.persist_events_and_notify(room_id, tuple(events_to_persist))
async def _check_event_auth(
self,
@@ -1269,8 +1244,7 @@ class FederationEventHandler:
claimed_auth_event_map:
A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
+ Possibly including events that were rejected, or are in the wrong room.
Only populated when populating outliers.
@@ -1505,64 +1479,22 @@ class FederationEventHandler:
# If we don't have all the auth events, we need to get them.
logger.info("auth_events contains unknown events: %s", missing_auth)
try:
- try:
- remote_auth_chain = await self._federation_client.get_event_auth(
- origin, event.room_id, event.event_id
- )
- except RequestSendFailed as e1:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e1)
- return context, auth_events
-
- seen_remotes = await self._store.have_seen_events(
- event.room_id, [e.event_id for e in remote_auth_chain]
+ await self._get_remote_auth_chain_for_event(
+ origin, event.room_id, event.event_id
)
-
- for auth_event in remote_auth_chain:
- if auth_event.event_id in seen_remotes:
- continue
-
- if auth_event.event_id == event.event_id:
- continue
-
- try:
- auth_ids = auth_event.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in remote_auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- auth_event.internal_metadata.outlier = True
-
- logger.debug(
- "_check_event_auth %s missing_auth: %s",
- event.event_id,
- auth_event.event_id,
- )
- missing_auth_event_context = (
- await self._state_handler.compute_event_context(auth_event)
- )
-
- missing_auth_event_context = await self._check_event_auth(
- origin,
- auth_event,
- missing_auth_event_context,
- claimed_auth_event_map=auth,
- )
- await self.persist_events_and_notify(
- event.room_id, [(auth_event, missing_auth_event_context)]
- )
-
- if auth_event.event_id in event_auth_events:
- auth_events[
- (auth_event.type, auth_event.state_key)
- ] = auth_event
- except AuthError:
- pass
-
except Exception:
logger.exception("Failed to get auth chain")
+ else:
+ # load any auth events we might have persisted from the database. This
+ # has the side-effect of correctly setting the rejected_reason on them.
+ auth_events.update(
+ {
+ (ae.type, ae.state_key): ae
+ for ae in await self._store.get_events_as_list(
+ missing_auth, allow_rejected=True
+ )
+ }
+ )
if event.internal_metadata.is_outlier():
# XXX: given that, for an outlier, we'll be working with the
@@ -1636,6 +1568,45 @@ class FederationEventHandler:
return context, auth_events
+ async def _get_remote_auth_chain_for_event(
+ self, destination: str, room_id: str, event_id: str
+ ) -> None:
+ """If we are missing some of an event's auth events, attempt to request them
+
+ Args:
+ destination: where to fetch the auth tree from
+ room_id: the room in which we are lacking auth events
+ event_id: the event for which we are lacking auth events
+ """
+ try:
+ remote_event_map = {
+ e.event_id: e
+ for e in await self._federation_client.get_event_auth(
+ destination, room_id, event_id
+ )
+ }
+ except RequestSendFailed as e1:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so lets just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1)
+ return
+
+ logger.info("/event_auth returned %i events", len(remote_event_map))
+
+ # `event` may be returned, but we should not yet process it.
+ remote_event_map.pop(event_id, None)
+
+ # nor should we reprocess any events we have already seen.
+ seen_remotes = await self._store.have_seen_events(
+ room_id, remote_event_map.keys()
+ )
+ for s in seen_remotes:
+ remote_event_map.pop(s, None)
+
+ await self._auth_and_persist_fetched_events(
+ destination, room_id, remote_event_map.values()
+ )
+
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
) -> EventContext:
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 8b8f1f41ca..fe8a995892 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -62,7 +62,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
- self._web_client_location = hs.config.invite_client_location
+ self._web_client_location = hs.config.email.invite_client_location
# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
@@ -419,7 +419,7 @@ class IdentityHandler(BaseHandler):
token_expires = (
self.hs.get_clock().time_msec()
- + self.hs.config.email_validation_token_lifetime
+ + self.hs.config.email.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
@@ -465,7 +465,7 @@ class IdentityHandler(BaseHandler):
if next_link:
params["next_link"] = next_link
- if self.hs.config.using_identity_server_from_trusted_list:
+ if self.hs.config.email.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
logger.warning(
'The config option "trust_identity_server_for_password_resets" '
@@ -518,7 +518,7 @@ class IdentityHandler(BaseHandler):
if next_link:
params["next_link"] = next_link
- if self.hs.config.using_identity_server_from_trusted_list:
+ if self.hs.config.email.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
logger.warning(
'The config option "trust_identity_server_for_password_resets" '
@@ -572,12 +572,12 @@ class IdentityHandler(BaseHandler):
validation_session = None
# Try to validate as email
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
- elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bf2763b0f3..fb3aa6a83d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -443,7 +443,7 @@ class EventCreationHandler:
)
self._block_events_without_consent_error = (
- self.config.block_events_without_consent_error
+ self.config.consent.block_events_without_consent_error
)
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
@@ -666,7 +666,7 @@ class EventCreationHandler:
self.validator.validate_new(event, self.config)
- return (event, context)
+ return event, context
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -692,10 +692,10 @@ class EventCreationHandler:
return False
async def _is_server_notices_room(self, room_id: str) -> bool:
- if self.config.server_notices_mxid is None:
+ if self.config.servernotices.server_notices_mxid is None:
return False
user_ids = await self.store.get_users_in_room(room_id)
- return self.config.server_notices_mxid in user_ids
+ return self.config.servernotices.server_notices_mxid in user_ids
async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy
@@ -731,8 +731,8 @@ class EventCreationHandler:
# exempt the system notices user
if (
- self.config.server_notices_mxid is not None
- and user_id == self.config.server_notices_mxid
+ self.config.servernotices.server_notices_mxid is not None
+ and user_id == self.config.servernotices.server_notices_mxid
):
return
@@ -744,7 +744,7 @@ class EventCreationHandler:
if u["appservice_id"] is not None:
# users registered by an appservice are exempt
return
- if u["consent_version"] == self.config.user_consent_version:
+ if u["consent_version"] == self.config.consent.user_consent_version:
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
@@ -1004,7 +1004,7 @@ class EventCreationHandler:
logger.debug("Created event %s", event.event_id)
- return (event, context)
+ return event, context
@measure_func("handle_new_client_event")
async def handle_new_client_event(
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index aed5a40a78..3665d91513 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
- self._callback_url: str = hs.config.oidc_callback_url
+ self._callback_url: str = hs.config.oidc.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index cd21efdcc6..eadd7ced09 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class PasswordPolicyHandler:
def __init__(self, hs: "HomeServer"):
- self.policy = hs.config.password_policy
- self.enabled = hs.config.password_policy_enabled
+ self.policy = hs.config.auth.password_policy
+ self.enabled = hs.config.auth.password_policy_enabled
# Regexps for the spec'd policy parameters.
self.regexp_digit = re.compile("[0-9]")
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index f06070bfcf..b23a1541bc 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -309,7 +309,7 @@ class ProfileHandler(BaseHandler):
async def on_profile_query(self, args: JsonDict) -> JsonDict:
"""Handles federation profile query requests."""
- if not self.hs.config.allow_profile_lookup_over_federation:
+ if not self.hs.config.federation.allow_profile_lookup_over_federation:
raise SynapseError(
403,
"Profile lookup over federation is disabled on this homeserver",
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 5881f09ebd..f21f33ada2 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -238,7 +238,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
if self.config.experimental.msc2285_enabled:
events = ReceiptEventSource.filter_out_hidden(events, user.to_string())
- return (events, to_key)
+ return events, to_key
async def get_new_events_as(
self, from_key: int, service: ApplicationService
@@ -270,7 +270,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
events.append(event)
- return (events, to_key)
+ return events, to_key
def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1c195c65db..4f99f137a2 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -97,7 +97,8 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._account_validity_handler = hs.get_account_validity_handler()
- self._server_notices_mxid = hs.config.server_notices_mxid
+ self._user_consent_version = self.hs.config.consent.user_consent_version
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
@@ -339,7 +340,7 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
- if not self.hs.config.user_consent_at_registration:
+ if not self.hs.config.consent.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
"Skipping auto-join for %s because auto-join for guests is disabled",
@@ -864,7 +865,9 @@ class RegistrationHandler(BaseHandler):
await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
- await self._on_user_consented(user_id, self.hs.config.user_consent_version)
+ # The terms type should only exist if consent is enabled.
+ assert self._user_consent_version is not None
+ await self._on_user_consented(user_id, self._user_consent_version)
async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
@@ -910,8 +913,8 @@ class RegistrationHandler(BaseHandler):
# getting mail spam where they weren't before if email
# notifs are set up on a homeserver)
if (
- self.hs.config.email_enable_notifs
- and self.hs.config.email_notif_for_new_users
+ self.hs.config.email.email_enable_notifs
+ and self.hs.config.email.email_notif_for_new_users
and token
):
# Pull the ID of the access token back out of the db
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 287ea2fd06..8fede5e935 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -126,7 +126,7 @@ class RoomCreationHandler(BaseHandler):
for preset_name, preset_config in self._presets_dict.items():
encrypted = (
preset_name
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
preset_config["encrypted"] = encrypted
@@ -141,7 +141,7 @@ class RoomCreationHandler(BaseHandler):
self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
)
- self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -649,8 +649,16 @@ class RoomCreationHandler(BaseHandler):
requester, config, is_requester_admin=is_requester_admin
)
- if not is_requester_admin and not await self.spam_checker.user_may_create_room(
- user_id
+ invite_3pid_list = config.get("invite_3pid", [])
+ invite_list = config.get("invite", [])
+
+ if not is_requester_admin and not (
+ await self.spam_checker.user_may_create_room(user_id)
+ and await self.spam_checker.user_may_create_room_with_invites(
+ user_id,
+ invite_list,
+ invite_3pid_list,
+ )
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -684,8 +692,6 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
- invite_3pid_list = config.get("invite_3pid", [])
- invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -757,7 +763,9 @@ class RoomCreationHandler(BaseHandler):
)
if is_public:
- if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias):
+ if not self.config.roomdirectory.is_publishing_room_allowed(
+ user_id, room_id, room_alias
+ ):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
@@ -1235,7 +1243,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
else:
end_key = to_key
- return (events, end_key)
+ return events, end_key
def get_current_key(self) -> RoomStreamToken:
return self.store.get_room_max_token()
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index c83ff585e3..c3d4199ed1 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -52,7 +52,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.enable_room_list_search = hs.config.enable_room_list_search
+ self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
] = ResponseCache(hs.get_clock(), "room_list")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4969ee395b..19b4e7c19c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -89,7 +89,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
- self._server_notices_mxid = self.config.server_notices_mxid
+ self._server_notices_mxid = self.config.servernotices.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4e28fb9685..fb26ee7ad7 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -1179,4 +1179,4 @@ def _child_events_comparison_key(
order = None
# Items without an order come last.
- return (order is None, order, child.origin_server_ts, child.room_id)
+ return order is None, order, child.origin_server_ts, child.room_id
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 185befbe9f..2fed9f377a 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -54,19 +54,18 @@ class Saml2SessionData:
class SamlHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._saml_idp_entityid = hs.config.saml2_idp_entityid
+ self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
+ self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid
- self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ self._saml2_session_lifetime = hs.config.saml2.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
- hs.config.saml2_grandfathered_mxid_source_attribute
+ hs.config.saml2.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
- self._error_template = hs.config.sso_error_template
# plugin to do custom mapping from saml response to mxid
- self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
- hs.config.saml2_user_mapping_provider_config,
+ self._user_mapping_provider = hs.config.saml2.saml2_user_mapping_provider_class(
+ hs.config.saml2.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
@@ -411,7 +410,7 @@ class DefaultSamlMappingProvider:
self._mxid_mapper = parsed_config.mxid_mapper
self._grandfathered_mxid_source_attribute = (
- module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+ module_api._hs.config.saml2.saml2_grandfathered_mxid_source_attribute
)
def get_remote_user_id(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e044251a13..49fde01cf0 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -184,15 +184,17 @@ class SsoHandler:
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
- self._error_template = hs.config.sso_error_template
- self._bad_user_template = hs.config.sso_auth_bad_user_template
+ self._error_template = hs.config.sso.sso_error_template
+ self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
- self._sso_auth_success_template = hs.config.sso_auth_success_template
+ self._sso_auth_success_template = hs.config.sso.sso_auth_success_template
- self._sso_update_profile_information = hs.config.sso_update_profile_information
+ self._sso_update_profile_information = (
+ hs.config.sso.sso_update_profile_information
+ )
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 9fc53333fc..bd3e6f2ec7 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -46,7 +46,7 @@ class StatsHandler:
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.stats_enabled = hs.config.stats_enabled
+ self.stats_enabled = hs.config.stats.stats_enabled
# The current position in the current_state_delta stream
self.pos: Optional[int] = None
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 9326330c90..d10e9b8ec4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -483,7 +483,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
events.append(self._make_event_for(room_id))
- return (events, handler._latest_room_serial)
+ return events, handler._latest_room_serial
async def get_new_events(
self,
@@ -507,7 +507,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
events.append(self._make_event_for(room_id))
- return (events, handler._latest_room_serial)
+ return events, handler._latest_room_serial
def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index ea9325e96a..8f5d465fa1 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -82,10 +82,10 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self._enabled = bool(hs.config.recaptcha_private_key)
+ self._enabled = bool(hs.config.captcha.recaptcha_private_key)
self._http_client = hs.get_proxied_http_client()
- self._url = hs.config.recaptcha_siteverify_api
- self._secret = hs.config.recaptcha_private_key
+ self._url = hs.config.captcha.recaptcha_siteverify_api
+ self._secret = hs.config.captcha.recaptcha_private_key
def is_enabled(self) -> bool:
return self._enabled
@@ -161,12 +161,17 @@ class _BaseThreepidAuthChecker:
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if (
+ self.hs.config.email.threepid_behaviour_email
+ == ThreepidBehaviour.REMOTE
+ ):
assert self.hs.config.account_threepid_delegate_email
threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
- elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ elif (
+ self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ ):
threepid = None
row = await self.store.get_threepid_validation_session(
medium,
@@ -218,7 +223,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
- return self.hs.config.threepid_behaviour_email in (
+ return self.hs.config.email.threepid_behaviour_email in (
ThreepidBehaviour.REMOTE,
ThreepidBehaviour.LOCAL,
)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8dc46d7674..b91e7cb501 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -61,7 +61,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
- self.search_all_users = hs.config.user_directory_search_all_users
+ self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
self.pos: Optional[int] = None
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ef10ec0937..cdc36b8d25 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -465,8 +465,9 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if (
- self.hs.config.federation_domain_whitelist is not None
- and request.destination not in self.hs.config.federation_domain_whitelist
+ self.hs.config.federation.federation_domain_whitelist is not None
+ and request.destination
+ not in self.hs.config.federation.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@@ -1186,7 +1187,7 @@ class MatrixFederationHttpClient:
request.method,
request.uri.decode("ascii"),
)
- return (length, headers)
+ return length, headers
def _flatten_response_never_received(e):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b79fa722e9..1a50305dcf 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,7 +21,6 @@ import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
-from io import BytesIO
from typing import (
Any,
Awaitable,
@@ -37,7 +36,7 @@ from typing import (
)
import jinja2
-from canonicaljson import iterencode_canonical_json
+from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer
@@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
-from twisted.web.static import File, NoRangeStaticProducer
+from twisted.web.static import File
from twisted.web.util import redirectTo
from synapse.api.errors import (
@@ -56,10 +55,11 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.context import preserve_fn
+from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
+from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__)
@@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
- request: Request,
+ request: SynapseRequest,
code: int,
response_object: Any,
):
@@ -620,16 +620,15 @@ class _ByteProducer:
self._request = None
-def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+def _encode_json_bytes(json_object: Any) -> bytes:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
- for chunk in json_encoder.iterencode(json_object):
- yield chunk.encode("utf-8")
+ return json_encoder.encode(json_object).encode("utf-8")
def respond_with_json(
- request: Request,
+ request: SynapseRequest,
code: int,
json_object: Any,
send_cors: bool = False,
@@ -659,7 +658,7 @@ def respond_with_json(
return None
if canonical_json:
- encoder = iterencode_canonical_json
+ encoder = encode_canonical_json
else:
encoder = _encode_json_bytes
@@ -670,7 +669,9 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
- _ByteProducer(request, encoder(json_object))
+ run_in_background(
+ _async_write_json_to_request_in_thread, request, encoder, json_object
+ )
return NOT_DONE_YET
@@ -706,15 +707,56 @@ def respond_with_json_bytes(
if send_cors:
set_cors_headers(request)
- # note that this is zero-copy (the bytesio shares a copy-on-write buffer with
- # the original `bytes`).
- bytes_io = BytesIO(json_bytes)
-
- producer = NoRangeStaticProducer(request, bytes_io)
- producer.start()
+ _write_bytes_to_request(request, json_bytes)
return NOT_DONE_YET
+async def _async_write_json_to_request_in_thread(
+ request: SynapseRequest,
+ json_encoder: Callable[[Any], bytes],
+ json_object: Any,
+):
+ """Encodes the given JSON object on a thread and then writes it to the
+ request.
+
+ This is done so that encoding large JSON objects doesn't block the reactor
+ thread.
+
+ Note: We don't use JsonEncoder.iterencode here as that falls back to the
+ Python implementation (rather than the C backend), which is *much* more
+ expensive.
+ """
+
+ json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
+
+ _write_bytes_to_request(request, json_str)
+
+
+def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
+ """Writes the bytes to the request using an appropriate producer.
+
+ Note: This should be used instead of `Request.write` to correctly handle
+ large response bodies.
+ """
+
+ # The problem with dumping all of the response into the `Request` object at
+ # once (via `Request.write`) is that doing so starts the timeout for the
+ # next request to be received: so if it takes longer than 60s to stream back
+ # the response to the client, the client never gets it.
+ #
+ # The correct solution is to use a Producer; then the timeout is only
+ # started once all of the content is sent over the TCP connection.
+
+ # To make sure we don't write all of the bytes at once we split it up into
+ # chunks.
+ chunk_size = 4096
+ bytes_generator = chunk_seq(bytes_to_write, chunk_size)
+
+ # We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the
+ # unit tests can't cope with being given a pull producer.
+ _ByteProducer(request, bytes_generator)
+
+
def set_cors_headers(request: Request):
"""Set the CORS headers so that javascript running in a web browsers can
use this API
diff --git a/synapse/http/site.py b/synapse/http/site.py
index dd4c749e16..755ad56637 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,14 @@
import contextlib
import logging
import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
+from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
- def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
- Request.__init__(self, channel, *args, **kw)
+ def __init__(
+ self,
+ channel: HTTPChannel,
+ site: "SynapseSite",
+ *args,
+ max_request_body_size: int = 1024,
+ **kw,
+ ):
+ super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
- self.site: SynapseSite = channel.site
+ self.synapse_site = site
+ self.reactor = site.reactor
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -97,7 +106,7 @@ class SynapseRequest(Request):
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"),
- self.site.site_tag,
+ self.synapse_site.site_tag,
)
def handleContentChunk(self, data: bytes) -> None:
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
- site_tag=self.site.site_tag,
+ site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
)
# override the Server header which is set by twisted
- self.setHeader("Server", self.site.server_version_string)
+ self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
- def processing(self):
+ def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@@ -346,10 +355,10 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method()
)
- self.site.access_logger.debug(
+ self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
- self.site.site_tag,
+ self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
)
@@ -388,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity:
requester = f"{authenticated_entity}|{requester}"
- self.site.access_logger.log(
+ self.synapse_site.access_logger.log(
log_level,
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
- self.site.site_tag,
+ self.synapse_site.site_tag,
requester,
processing_time,
response_send_time,
@@ -522,7 +531,7 @@ class SynapseSite(Site):
site_tag: str,
config: ListenerConfig,
resource: IResource,
- server_version_string,
+ server_version_string: str,
max_request_body_size: int,
reactor: IReactorTime,
):
@@ -542,6 +551,7 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
+ self.reactor = reactor
assert config.http_options is not None
proxied = config.http_options.x_forwarded
@@ -550,6 +560,7 @@ class SynapseSite(Site):
def request_factory(channel, queued: bool) -> Request:
return request_class(
channel,
+ self,
max_request_body_size=max_request_body_size,
queued=queued,
)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c6c4d3bd29..03d2dd94f6 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -363,7 +363,7 @@ def noop_context_manager(*args, **kwargs):
def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
- if not hs.config.opentracer_enabled:
+ if not hs.config.tracing.opentracer_enabled:
# We don't have a tracer
opentracing = None
return
@@ -377,12 +377,12 @@ def init_tracer(hs: "HomeServer"):
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
- set_homeserver_whitelist(hs.config.opentracer_whitelist)
+ set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist)
from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
config = JaegerConfig(
- config=hs.config.jaeger_config,
+ config=hs.config.tracing.jaeger_config,
service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
scope_manager=LogContextScopeManager(hs.config),
metrics_factory=PrometheusMetricsFactory(),
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 3196c2bec6..8ae21bc43c 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,8 +24,10 @@ from typing import (
List,
Optional,
Tuple,
+ Union,
)
+import attr
import jinja2
from twisted.internet import defer
@@ -46,7 +48,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester
+from synapse.types import (
+ DomainSpecificString,
+ JsonDict,
+ Requester,
+ UserID,
+ UserInfo,
+ create_requester,
+)
from synapse.util import Clock
from synapse.util.caches.descriptors import cached
@@ -79,6 +88,18 @@ __all__ = [
logger = logging.getLogger(__name__)
+@attr.s(auto_attribs=True)
+class UserIpAndAgent:
+ """
+ An IP address and user agent used by a user to connect to this homeserver.
+ """
+
+ ip: str
+ user_agent: str
+ # The time at which this user agent/ip was last seen.
+ last_seen: int
+
+
class ModuleApi:
"""A proxy object that gets passed to various plugin modules so they
can register new users etc if necessary.
@@ -98,14 +119,16 @@ class ModuleApi:
self.custom_template_dir = hs.config.server.custom_template_directory
try:
- app_name = self._hs.config.email_app_name
+ app_name = self._hs.config.email.email_app_name
- self._from_string = self._hs.config.email_notif_from % {"app": app_name}
+ self._from_string = self._hs.config.email.email_notif_from % {
+ "app": app_name
+ }
except (KeyError, TypeError):
# If substitution failed (which can happen if the string contains
# placeholders other than just "app", or if the type of the placeholder is
# not a string), fall back to the bare strings.
- self._from_string = self._hs.config.email_notif_from
+ self._from_string = self._hs.config.email.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
@@ -700,6 +723,65 @@ class ModuleApi:
(td for td in (self.custom_template_dir, custom_template_directory) if td),
)
+ def is_mine(self, id: Union[str, DomainSpecificString]) -> bool:
+ """
+ Checks whether an ID (user id, room, ...) comes from this homeserver.
+
+ Args:
+ id: any Matrix id (e.g. user id, room id, ...), either as a raw id,
+ e.g. string "@user:example.com" or as a parsed UserID, RoomID, ...
+ Returns:
+ True if id comes from this homeserver, False otherwise.
+
+ Added in Synapse v1.44.0.
+ """
+ if isinstance(id, DomainSpecificString):
+ return self._hs.is_mine(id)
+ else:
+ return self._hs.is_mine_id(id)
+
+ async def get_user_ip_and_agents(
+ self, user_id: str, since_ts: int = 0
+ ) -> List[UserIpAndAgent]:
+ """
+ Return the list of user IPs and agents for a user.
+
+ Args:
+ user_id: the id of a user, local or remote
+ since_ts: a timestamp in seconds since the epoch,
+ or the epoch itself if not specified.
+ Returns:
+ The list of all UserIpAndAgent that the user has
+ used to connect to this homeserver since `since_ts`.
+ If the user is remote, this list is empty.
+
+ Added in Synapse v1.44.0.
+ """
+ # Don't hit the db if this is not a local user.
+ is_mine = False
+ try:
+ # Let's be defensive against ill-formed strings.
+ if self.is_mine(user_id):
+ is_mine = True
+ except Exception:
+ pass
+
+ if is_mine:
+ raw_data = await self._store.get_user_ip_and_agents(
+ UserID.from_string(user_id), since_ts
+ )
+ # Sanitize some of the data. We don't want to return tokens.
+ return [
+ UserIpAndAgent(
+ ip=str(data["ip"]),
+ user_agent=str(data["user_agent"]),
+ last_seen=int(data["last_seen"]),
+ )
+ for data in raw_data
+ ]
+ else:
+ return []
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index e08e125cb8..cf5abdfbda 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -184,7 +184,7 @@ class EmailPusher(Pusher):
should_notify_at = max(notif_ready_at, room_ready_at)
- if should_notify_at < self.clock.time_msec():
+ if should_notify_at <= self.clock.time_msec():
# one of our notifications is ready for sending, so we send
# *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 33430b167c..d88081e96d 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -73,7 +73,9 @@ class HttpPusher(Pusher):
self.failing_since = pusher_config.failing_since
self.timed_call: Optional[IDelayedCall] = None
self._is_processing = False
- self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
+ self._group_unread_count_by_room = (
+ hs.config.push.push_group_unread_count_by_room
+ )
self._pusherpool = hs.get_pusherpool()
self.data = pusher_config.data
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 29ed346d37..b57e094091 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -77,4 +77,4 @@ class PusherFactory:
if isinstance(brand, str):
return brand
- return self.config.email_app_name
+ return self.config.email.email_app_name
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 25589b0042..f1b78d09f9 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -168,8 +168,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
client = hs.get_simple_http_client()
local_instance_name = hs.get_instance_name()
- master_host = hs.config.worker_replication_host
- master_port = hs.config.worker_replication_http_port
+ master_host = hs.config.worker.worker_replication_host
+ master_port = hs.config.worker.worker_replication_http_port
instance_map = hs.config.worker.instance_map
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 509ed7fb13..1438a82b60 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -322,8 +322,8 @@ class ReplicationCommandHandler:
else:
client_name = hs.get_instance_name()
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_port
+ host = hs.config.worker.worker_replication_host
+ port = hs.config.worker.worker_replication_port
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index a03774c98a..e1506deb2b 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -267,7 +267,7 @@ def register_servlets_for_client_rest_resource(
# Load the media repo ones if we're using them. Otherwise load the servlets which
# don't need a media repo (typically readonly admin APIs).
- if hs.config.can_load_media_repo:
+ if hs.config.media.can_load_media_repo:
register_servlets_for_media_repo(hs, http_server)
else:
ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 5a1c929d85..aba48f6e7b 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -113,7 +113,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
# A string of all the characters allowed to be in a registration_token
- self.allowed_chars = string.ascii_letters + string.digits + "-_"
+ self.allowed_chars = string.ascii_letters + string.digits + "._~-"
self.allowed_chars_set = set(self.allowed_chars)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 8f781f745f..a4823ca6e7 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -213,7 +213,7 @@ class RoomRestServlet(RestServlet):
members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
- return (200, ret)
+ return 200, ret
async def on_DELETE(
self, request: SynapseRequest, room_id: str
@@ -668,4 +668,4 @@ async def _delete_room(
if purge:
await pagination_handler.purge_room(room_id, force=force_purge)
- return (200, ret)
+ return 200, ret
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 681e491826..46bfec4623 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -368,8 +368,8 @@ class UserRestServletV2(RestServlet):
user_id, medium, address, current_time
)
if (
- self.hs.config.email_enable_notifs
- and self.hs.config.email_notif_for_new_users
+ self.hs.config.email.email_enable_notifs
+ and self.hs.config.email.email_notif_for_new_users
):
await self.pusher_pool.add_pusher(
user_id=user_id,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index aefaaa8ae8..6a7608d60b 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -64,17 +64,17 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.config = hs.config
self.identity_handler = hs.get_identity_handler()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
- app_name=self.config.email_app_name,
- template_html=self.config.email_password_reset_template_html,
- template_text=self.config.email_password_reset_template_text,
+ app_name=self.config.email.email_app_name,
+ template_html=self.config.email.email_password_reset_template_html,
+ template_text=self.config.email.email_password_reset_template_text,
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"User password resets have been disabled due to lack of email config"
)
@@ -129,7 +129,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
@@ -349,17 +349,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
- app_name=self.config.email_app_name,
- template_html=self.config.email_add_threepid_template_html,
- template_text=self.config.email_add_threepid_template_text,
+ app_name=self.config.email.email_app_name,
+ template_html=self.config.email.email_add_threepid_template_html,
+ template_text=self.config.email.email_add_threepid_template_text,
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
@@ -413,7 +413,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
@@ -534,21 +534,21 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
- self.config.email_add_threepid_template_failure_html
+ self.config.email.email_add_threepid_template_failure_html
)
async def on_GET(self, request: Request) -> None:
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
400, "Adding an email to your account is disabled on this server"
)
- elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
raise SynapseError(
400,
"This homeserver is not validating threepids. Use an identity server "
@@ -575,7 +575,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_add_threepid_template_success_html_content
+ html = self.config.email.email_add_threepid_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 7bb7801472..282861fae2 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -47,7 +47,7 @@ class AuthRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- self.recaptcha_template = hs.config.recaptcha_template
+ self.recaptcha_template = hs.config.captcha.recaptcha_template
self.terms_template = hs.config.terms_template
self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
@@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet):
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- sitekey=self.hs.config.recaptcha_public_key,
+ sitekey=self.hs.config.captcha.recaptcha_public_key,
)
elif stagetype == LoginType.TERMS:
html = self.terms_template.render(
@@ -70,7 +70,7 @@ class AuthRestServlet(RestServlet):
terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.server.public_baseurl,
- self.hs.config.user_consent_version,
+ self.hs.config.consent.user_consent_version,
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
@@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet):
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- sitekey=self.hs.config.recaptcha_public_key,
+ sitekey=self.hs.config.captcha.recaptcha_public_key,
error=e.msg,
)
else:
@@ -139,7 +139,7 @@ class AuthRestServlet(RestServlet):
terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.server.public_baseurl,
- self.hs.config.user_consent_version,
+ self.hs.config.consent.user_consent_version,
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 25bc3c8f47..8566dc5cb5 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -211,7 +211,7 @@ class DehydratedDeviceServlet(RestServlet):
if dehydrated_device is not None:
(device_id, device_data) = dehydrated_device
result = {"device_id": device_id, "device_data": device_data}
- return (200, result)
+ return 200, result
else:
raise errors.NotFoundError("No dehydrated device available")
@@ -293,7 +293,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
submission["device_id"],
)
- return (200, result)
+ return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index a6ede7e2f3..fa5c173f4b 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -69,16 +69,16 @@ class LoginRestServlet(RestServlet):
self.hs = hs
# JWT configuration variables.
- self.jwt_enabled = hs.config.jwt_enabled
- self.jwt_secret = hs.config.jwt_secret
- self.jwt_algorithm = hs.config.jwt_algorithm
- self.jwt_issuer = hs.config.jwt_issuer
- self.jwt_audiences = hs.config.jwt_audiences
+ self.jwt_enabled = hs.config.jwt.jwt_enabled
+ self.jwt_secret = hs.config.jwt.jwt_secret
+ self.jwt_algorithm = hs.config.jwt.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt.jwt_issuer
+ self.jwt_audiences = hs.config.jwt.jwt_audiences
# SSO configuration.
- self.saml2_enabled = hs.config.saml2_enabled
- self.cas_enabled = hs.config.cas_enabled
- self.oidc_enabled = hs.config.oidc_enabled
+ self.saml2_enabled = hs.config.saml2.saml2_enabled
+ self.cas_enabled = hs.config.cas.cas_enabled
+ self.oidc_enabled = hs.config.oidc.oidc_enabled
self._msc2918_enabled = hs.config.access_token_lifetime is not None
self.auth = hs.get_auth()
@@ -559,7 +559,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
- if hs.config.cas_enabled:
+ if hs.config.cas.cas_enabled:
CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py
index 6d64efb165..9f1908004b 100644
--- a/synapse/rest/client/password_policy.py
+++ b/synapse/rest/client/password_policy.py
@@ -35,12 +35,12 @@ class PasswordPolicyServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.policy = hs.config.password_policy
- self.enabled = hs.config.password_policy_enabled
+ self.policy = hs.config.auth.password_policy
+ self.enabled = hs.config.auth.password_policy_enabled
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy:
- return (200, {})
+ return 200, {}
policy = {}
@@ -54,7 +54,7 @@ class PasswordPolicyServlet(RestServlet):
if param in self.policy:
policy["m.%s" % param] = self.policy[param]
- return (200, policy)
+ return 200, policy
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index abe4d7e205..48b0062cf4 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -75,17 +75,19 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.config = hs.config
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
- app_name=self.config.email_app_name,
- template_html=self.config.email_registration_template_html,
- template_text=self.config.email_registration_template_text,
+ app_name=self.config.email.email_app_name,
+ template_html=self.config.email.email_registration_template_html,
+ template_text=self.config.email.email_registration_template_text,
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if (
+ self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
+ ):
logger.warning(
"Email registration has been disabled due to lack of email config"
)
@@ -137,7 +139,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
@@ -259,9 +261,9 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
- self.config.email_registration_template_failure_html
+ self.config.email.email_registration_template_failure_html
)
async def on_GET(self, request: Request, medium: str) -> None:
@@ -269,8 +271,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
raise SynapseError(
400, "This medium is currently not supported for registration"
)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"User registration via email has been disabled due to lack of email config"
)
@@ -303,7 +305,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_registration_template_success_html_content
+ html = self.config.email.email_registration_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
@@ -897,12 +899,12 @@ def _calculate_registration_flows(
flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
# Prepend m.login.terms to all flows if we're requiring consent
- if config.user_consent_at_registration:
+ if config.consent.user_consent_at_registration:
for flow in flows:
flow.insert(0, LoginType.TERMS)
# Prepend recaptcha to all flows if we're requiring captcha
- if config.enable_registration_captcha:
+ if config.captcha.enable_registration_captcha:
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 8852811114..a47d9bd01d 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectorySearchRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- if not self.hs.config.user_directory_search_enabled:
+ if not self.hs.config.userdirectory.user_directory_search_enabled:
return 200, {"limited": False, "results": []}
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index a1a815cf82..b52a296d8f 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -42,15 +42,15 @@ class VersionsRestServlet(RestServlet):
# Calculate these once since they shouldn't change after start-up.
self.e2ee_forced_public = (
RoomCreationPreset.PUBLIC_CHAT
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
self.e2ee_forced_private = (
RoomCreationPreset.PRIVATE_CHAT
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
self.e2ee_forced_trusted_private = (
RoomCreationPreset.TRUSTED_PRIVATE_CHAT
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index 9d46ed3af3..ea2b8aa45f 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -37,14 +37,14 @@ class VoipRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
- request, self.hs.config.turn_allow_guests
+ request, self.hs.config.voip.turn_allow_guests
)
- turnUris = self.hs.config.turn_uris
- turnSecret = self.hs.config.turn_shared_secret
- turnUsername = self.hs.config.turn_username
- turnPassword = self.hs.config.turn_password
- userLifetime = self.hs.config.turn_user_lifetime
+ turnUris = self.hs.config.voip.turn_uris
+ turnSecret = self.hs.config.voip.turn_shared_secret
+ turnUsername = self.hs.config.voip.turn_username
+ turnPassword = self.hs.config.voip.turn_password
+ userLifetime = self.hs.config.voip.turn_user_lifetime
if turnUris and turnSecret and userLifetime:
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 06e0fbde22..3d2afacc50 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -84,14 +84,15 @@ class ConsentResource(DirectServeHtmlResource):
# this is required by the request_handler wrapper
self.clock = hs.get_clock()
- self._default_consent_version = hs.config.user_consent_version
- if self._default_consent_version is None:
+ # Consent must be configured to create this resource.
+ default_consent_version = hs.config.consent.user_consent_version
+ consent_template_directory = hs.config.consent.user_consent_template_dir
+ if default_consent_version is None or consent_template_directory is None:
raise ConfigError(
"Consent resource is enabled but user_consent section is "
"missing in config file."
)
-
- consent_template_directory = hs.config.user_consent_template_dir
+ self._default_consent_version = default_consent_version
# TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(consent_template_directory)
@@ -99,13 +100,13 @@ class ConsentResource(DirectServeHtmlResource):
loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
)
- if hs.config.form_secret is None:
+ if hs.config.key.form_secret is None:
raise ConfigError(
"Consent resource is enabled but form_secret is not set in "
"config file. It should be set to an arbitrary secret string."
)
- self._hmac_secret = hs.config.form_secret.encode("utf-8")
+ self._hmac_secret = hs.config.key.form_secret.encode("utf-8")
async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index ebe243bcfd..12b3ae120c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -70,19 +70,19 @@ class LocalKey(Resource):
Resource.__init__(self)
def update_response_body(self, time_now_msec: int) -> None:
- refresh_interval = self.config.key_refresh_interval
+ refresh_interval = self.config.key.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self) -> JsonDict:
verify_keys = {}
- for key in self.config.signing_key:
+ for key in self.config.key.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)}
old_verify_keys = {}
- for key_id, key in self.config.old_signing_keys.items():
+ for key_id, key in self.config.key.old_signing_keys.items():
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
"key": encode_base64(verify_key_bytes),
@@ -95,13 +95,13 @@ class LocalKey(Resource):
"verify_keys": verify_keys,
"old_verify_keys": old_verify_keys,
}
- for key in self.config.signing_key:
+ for key in self.config.key.signing_key:
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
- if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
+ if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
return respond_with_json_bytes(request, 200, self.response_body)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index d8fd7938a4..3923ba8439 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
-from twisted.web.server import Request
-
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -97,10 +96,12 @@ class RemoteKey(DirectServeJsonResource):
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
- self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.federation_domain_whitelist = (
+ hs.config.federation.federation_domain_whitelist
+ )
self.config = hs.config
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
@@ -117,7 +118,7 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request: Request) -> None:
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
@@ -126,7 +127,7 @@ class RemoteKey(DirectServeJsonResource):
async def query_keys(
self,
- request: Request,
+ request: SynapseRequest,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
@@ -235,7 +236,7 @@ class RemoteKey(DirectServeJsonResource):
signed_keys = []
for key_json in json_results:
key_json = json_decoder.decode(key_json.decode("utf-8"))
- for signing_key in self.config.key_server_signing_keys:
+ for signing_key in self.config.key.key_server_signing_keys:
key_json = sign_json(
key_json, self.config.server.server_name, signing_key
)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 7c881f2bdb..014fa893d6 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
+from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii
@@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
)
-def respond_404(request: Request) -> None:
+def respond_404(request: SynapseRequest) -> None:
respond_with_json(
request,
404,
@@ -84,7 +85,7 @@ def respond_404(request: Request) -> None:
async def respond_with_file(
- request: Request,
+ request: SynapseRequest,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
@@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:
async def respond_with_responder(
- request: Request,
+ request: SynapseRequest,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index a1d36e5cf1..a95804d327 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -16,8 +16,6 @@
from typing import TYPE_CHECKING
-from twisted.web.server import Request
-
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest
@@ -33,11 +31,11 @@ class MediaConfigResource(DirectServeJsonResource):
config = hs.config
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- self.limits_dict = {"m.upload.size": config.max_upload_size}
+ self.limits_dict = {"m.upload.size": config.media.max_upload_size}
async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- async def _async_render_OPTIONS(self, request: Request) -> None:
+ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d6d938953e..6180fa575e 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,10 +15,9 @@
import logging
from typing import TYPE_CHECKING
-from twisted.web.server import Request
-
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
+from synapse.http.site import SynapseRequest
from ._base import parse_media_id, respond_404
@@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
self.media_repo = media_repo
self.server_name = hs.hostname
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 39bbe4e874..08bd85f664 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -195,23 +195,24 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
- def url_cache_thumbnail_directory(self, media_id: str) -> str:
+ def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
- return os.path.join(
- self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
- )
+ return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:])
else:
return os.path.join(
- self.base_path,
"url_cache_thumbnails",
media_id[0:2],
media_id[2:4],
media_id[4:],
)
+ url_cache_thumbnail_directory = _wrap_in_base_path(
+ url_cache_thumbnail_directory_rel
+ )
+
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 50e4c9e29f..abd88a2d4f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -23,7 +23,6 @@ import twisted.internet.error
import twisted.web.http
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
-from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,
@@ -34,6 +33,7 @@ from synapse.api.errors import (
)
from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
+from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@@ -76,29 +76,35 @@ class MediaRepository:
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
- self.max_upload_size = hs.config.max_upload_size
- self.max_image_pixels = hs.config.max_image_pixels
+ self.max_upload_size = hs.config.media.max_upload_size
+ self.max_image_pixels = hs.config.media.max_image_pixels
Thumbnailer.set_limits(self.max_image_pixels)
- self.primary_base_path: str = hs.config.media_store_path
+ self.primary_base_path: str = hs.config.media.media_store_path
self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
- self.dynamic_thumbnails = hs.config.dynamic_thumbnails
- self.thumbnail_requirements = hs.config.thumbnail_requirements
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+ self.thumbnail_requirements = hs.config.media.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
self.recently_accessed_locals: Set[str] = set()
- self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.federation_domain_whitelist = (
+ hs.config.federation.federation_domain_whitelist
+ )
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
- for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ for (
+ clz,
+ provider_config,
+ wrapper_config,
+ ) in hs.config.media.media_storage_providers:
backend = clz(hs, provider_config)
provider = StorageProviderWrapper(
backend,
@@ -187,7 +193,7 @@ class MediaRepository:
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(
- self, request: Request, media_id: str, name: Optional[str]
+ self, request: SynapseRequest, media_id: str, name: Optional[str]
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -221,7 +227,11 @@ class MediaRepository:
)
async def get_remote_media(
- self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ name: Optional[str],
) -> None:
"""Respond to requests for remote media.
@@ -969,7 +979,7 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
- if not hs.config.can_load_media_repo:
+ if not hs.config.media.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
super().__init__()
@@ -980,7 +990,7 @@ class MediaRepositoryResource(Resource):
self.putChild(
b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
)
- if hs.config.url_preview_enabled:
+ if hs.config.media.url_preview_enabled:
self.putChild(
b"preview_url",
PreviewUrlResource(hs, media_repo, media_repo.media_storage),
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 01fada8fb5..fca239d8c7 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -132,8 +132,7 @@ class MediaStorage:
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
finished_called = [False]
@@ -244,8 +243,7 @@ class MediaStorage:
return legacy_local_path
dirname = os.path.dirname(local_path)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
for provider in self.storage_providers:
res: Any = await provider.fetch(path, file_info)
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 8b74e72655..e04671fb95 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
import attr
@@ -22,6 +22,8 @@ from synapse.types import JsonDict
from synapse.util import json_decoder
if TYPE_CHECKING:
+ from lxml import etree
+
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
class OEmbedResult:
# The Open Graph result (converted from the oEmbed result).
open_graph_result: JsonDict
- # Number of seconds to cache the content, according to the oEmbed response.
+ # Number of milliseconds to cache the content, according to the oEmbed response.
#
# This will be None if no cache-age is provided in the oEmbed response (or
# if the oEmbed response cannot be turned into an Open Graph response).
@@ -119,10 +121,22 @@ class OEmbedProvider:
# Ensure the cache age is None or an int.
cache_age = oembed.get("cache_age")
if cache_age:
- cache_age = int(cache_age)
+ cache_age = int(cache_age) * 1000
# The results.
- open_graph_response = {"og:title": oembed.get("title")}
+ open_graph_response = {
+ "og:url": url,
+ }
+
+ # Use either title or author's name as the title.
+ title = oembed.get("title") or oembed.get("author_name")
+ if title:
+ open_graph_response["og:title"] = title
+
+ # Use the provider name and as the site.
+ provider_name = oembed.get("provider_name")
+ if provider_name:
+ open_graph_response["og:site_name"] = provider_name
# If a thumbnail exists, use it. Note that dimensions will be calculated later.
if "thumbnail_url" in oembed:
@@ -137,6 +151,15 @@ class OEmbedProvider:
# If this is a photo, use the full image, not the thumbnail.
open_graph_response["og:image"] = oembed["url"]
+ elif oembed_type == "video":
+ open_graph_response["og:type"] = "video.other"
+ calc_description_and_urls(open_graph_response, oembed["html"])
+ open_graph_response["og:video:width"] = oembed["width"]
+ open_graph_response["og:video:height"] = oembed["height"]
+
+ elif oembed_type == "link":
+ open_graph_response["og:type"] = "website"
+
else:
raise RuntimeError(f"Unknown oEmbed type: {oembed_type}")
@@ -149,6 +172,14 @@ class OEmbedProvider:
return OEmbedResult(open_graph_response, cache_age)
+def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
+ results = []
+ for tag in tree.xpath("//*/" + tag_name):
+ if "src" in tag.attrib:
+ results.append(tag.attrib["src"])
+ return results
+
+
def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None:
"""
Calculate description for an HTML document.
@@ -179,6 +210,16 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
if tree is None:
return
+ # Attempt to find interesting URLs (images, videos, embeds).
+ if "og:image" not in open_graph_response:
+ image_urls = _fetch_urls(tree, "img")
+ if image_urls:
+ open_graph_response["og:image"] = image_urls[0]
+
+ video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed")
+ if video_urls:
+ open_graph_response["og:video"] = video_urls[0]
+
from synapse.rest.media.v1.preview_url_resource import _calc_description
description = _calc_description(tree)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0a0b476d2b..79a42b2455 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -29,7 +29,6 @@ import attr
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
-from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -126,14 +125,14 @@ class PreviewUrlResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.filepaths = media_repo.filepaths
- self.max_spider_size = hs.config.max_spider_size
+ self.max_spider_size = hs.config.media.max_spider_size
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.client = SimpleHttpClient(
hs,
treq_args={"browser_like_redirects": True},
- ip_whitelist=hs.config.url_preview_ip_range_whitelist,
- ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ ip_whitelist=hs.config.media.url_preview_ip_range_whitelist,
+ ip_blacklist=hs.config.media.url_preview_ip_range_blacklist,
use_proxy=True,
)
self.media_repo = media_repo
@@ -151,8 +150,8 @@ class PreviewUrlResource(DirectServeJsonResource):
or instance_running_jobs == hs.get_instance_name()
)
- self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
- self.url_preview_accept_language = hs.config.url_preview_accept_language
+ self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist
+ self.url_preview_accept_language = hs.config.media.url_preview_accept_language
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
@@ -168,7 +167,7 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
- async def _async_render_OPTIONS(self, request: Request) -> None:
+ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
@@ -305,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource):
with open(media_info.filename, "rb") as file:
body = file.read()
- oembed_response = self._oembed.parse_oembed_response(media_info.uri, body)
+ oembed_response = self._oembed.parse_oembed_response(url, body)
og = oembed_response.open_graph_result
# Use the cache age from the oEmbed result, instead of the HTTP response.
@@ -486,7 +485,6 @@ class PreviewUrlResource(DirectServeJsonResource):
async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails."""
- # TODO: Delete from backup media store
assert self._worker_run_media_background_jobs
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 6c9969e55f..18bf977d3d 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -93,6 +93,11 @@ class StorageProviderWrapper(StorageProvider):
if file_info.server_name and not self.store_remote:
return None
+ if file_info.url_cache:
+ # The URL preview cache is short lived and not worth offloading or
+ # backing up.
+ return None
+
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
@@ -110,6 +115,11 @@ class StorageProviderWrapper(StorageProvider):
run_in_background(store)
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+ if file_info.url_cache:
+ # Files in the URL preview cache definitely aren't stored here,
+ # so avoid any potentially slow I/O or network access.
+ return None
+
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
@@ -125,7 +135,7 @@ class FileStorageProviderBackend(StorageProvider):
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
- self.cache_directory = hs.config.media_store_path
+ self.cache_directory = hs.config.media.media_store_path
self.base_directory = config
def __str__(self) -> str:
@@ -138,8 +148,7 @@ class FileStorageProviderBackend(StorageProvider):
backup_fname = os.path.join(self.base_directory, path)
dirname = os.path.dirname(backup_fname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 22f43d8531..ed91ef5a42 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,11 +17,10 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-from twisted.web.server import Request
-
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
@@ -54,10 +53,10 @@ class ThumbnailResource(DirectServeJsonResource):
self.store = hs.get_datastore()
self.media_repo = media_repo
self.media_storage = media_storage
- self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.server_name = hs.hostname
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_local_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
media_id: str,
width: int,
height: int,
@@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
@@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
@@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_remote_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
@@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_and_respond_with_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 146adca8f1..7dcb1428e4 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -16,8 +16,6 @@
import logging
from typing import IO, TYPE_CHECKING, Dict, List, Optional
-from twisted.web.server import Request
-
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args
@@ -43,10 +41,10 @@ class UploadResource(DirectServeJsonResource):
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.auth = hs.get_auth()
- self.max_upload_size = hs.config.max_upload_size
+ self.max_upload_size = hs.config.media.max_upload_size
self.clock = hs.get_clock()
- async def _async_render_OPTIONS(self, request: Request) -> None:
+ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: SynapseRequest) -> None:
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index 47a2f72b32..6ad558f5d1 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -45,12 +45,12 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
# provider-specific SSO bits. Only load these if they are enabled, since they
# rely on optional dependencies.
- if hs.config.oidc_enabled:
+ if hs.config.oidc.oidc_enabled:
from synapse.rest.synapse.client.oidc import OIDCResource
resources["/_synapse/client/oidc"] = OIDCResource(hs)
- if hs.config.saml2_enabled:
+ if hs.config.saml2.saml2_enabled:
from synapse.rest.synapse.client.saml2 import SAML2Resource
res = SAML2Resource(hs)
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index f2800bf2db..28a67f04e3 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -47,20 +47,20 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
self.store = hs.get_datastore()
self._local_threepid_handling_disabled_due_to_email_config = (
- hs.config.local_threepid_handling_disabled_due_to_email_config
+ hs.config.email.local_threepid_handling_disabled_due_to_email_config
)
self._confirmation_email_template = (
- hs.config.email_password_reset_template_confirmation_html
+ hs.config.email.email_password_reset_template_confirmation_html
)
self._email_password_reset_template_success_html = (
- hs.config.email_password_reset_template_success_html_content
+ hs.config.email.email_password_reset_template_success_html_content
)
self._failure_email_template = (
- hs.config.email_password_reset_template_failure_html
+ hs.config.email.email_password_reset_template_failure_html
)
# This resource should not be mounted if threepid behaviour is not LOCAL
- assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
sid = parse_string(request, "sid", required=True)
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 64378ed57b..d8eae3970d 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
- self.sp_config = hs.config.saml2_sp_config
+ self.sp_config = hs.config.saml2.saml2_sp_config
def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string(
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 4e0f814035..e09a25591f 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -36,9 +36,11 @@ class ConsentServerNotices:
self._users_in_progress: Set[str] = set()
- self._current_consent_version = hs.config.user_consent_version
- self._server_notice_content = hs.config.user_consent_server_notice_content
- self._send_to_guests = hs.config.user_consent_server_notice_to_guests
+ self._current_consent_version = hs.config.consent.user_consent_version
+ self._server_notice_content = (
+ hs.config.consent.user_consent_server_notice_content
+ )
+ self._send_to_guests = hs.config.consent.user_consent_server_notice_to_guests
if self._server_notice_content is not None:
if not self._server_notices_manager.is_enabled():
@@ -63,6 +65,9 @@ class ConsentServerNotices:
# not enabled
return
+ # A consent version must be given.
+ assert self._current_consent_version is not None
+
# make sure we don't send two messages to the same user at once
if user_id in self._users_in_progress:
return
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index d87a538917..cd1c5ff6f4 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -39,7 +39,7 @@ class ServerNoticesManager:
self._server_name = hs.hostname
self._notifier = hs.get_notifier()
- self.server_notices_mxid = self._config.server_notices_mxid
+ self.server_notices_mxid = self._config.servernotices.server_notices_mxid
def is_enabled(self):
"""Checks if server notices are enabled on this server.
@@ -47,7 +47,7 @@ class ServerNoticesManager:
Returns:
bool
"""
- return self._config.server_notices_mxid is not None
+ return self.server_notices_mxid is not None
async def send_notice(
self,
@@ -71,9 +71,9 @@ class ServerNoticesManager:
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
- system_mxid = self._config.server_notices_mxid
+ assert self.server_notices_mxid is not None
requester = create_requester(
- system_mxid, authenticated_entity=self._server_name
+ self.server_notices_mxid, authenticated_entity=self._server_name
)
logger.info("Sending server notice to %s", user_id)
@@ -81,7 +81,7 @@ class ServerNoticesManager:
event_dict = {
"type": type,
"room_id": room_id,
- "sender": system_mxid,
+ "sender": self.server_notices_mxid,
"content": event_content,
}
@@ -106,7 +106,7 @@ class ServerNoticesManager:
Returns:
room id of notice room.
"""
- if not self.is_enabled():
+ if self.server_notices_mxid is None:
raise Exception("Server notices not enabled")
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
@@ -139,12 +139,12 @@ class ServerNoticesManager:
# avatar, we have to use both.
join_profile = None
if (
- self._config.server_notices_mxid_display_name is not None
- or self._config.server_notices_mxid_avatar_url is not None
+ self._config.servernotices.server_notices_mxid_display_name is not None
+ or self._config.servernotices.server_notices_mxid_avatar_url is not None
):
join_profile = {
- "displayname": self._config.server_notices_mxid_display_name,
- "avatar_url": self._config.server_notices_mxid_avatar_url,
+ "displayname": self._config.servernotices.server_notices_mxid_display_name,
+ "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url,
}
requester = create_requester(
@@ -154,7 +154,7 @@ class ServerNoticesManager:
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
- "name": self._config.server_notices_room_name,
+ "name": self._config.servernotices.server_notices_room_name,
"power_level_content_override": {"users_default": -10},
},
ratelimit=False,
@@ -178,6 +178,7 @@ class ServerNoticesManager:
user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to.
"""
+ assert self.server_notices_mxid is not None
requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 463ce58dae..c981df3f18 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -263,7 +263,9 @@ class StateHandler:
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
) -> EventContext:
- """Build an EventContext structure for the event.
+ """Build an EventContext structure for a non-outlier event.
+
+ (for an outlier, call EventContext.for_outlier directly)
This works out what the current state should be for the event, and
generates a new state group if necessary.
@@ -278,35 +280,7 @@ class StateHandler:
The event context.
"""
- if event.internal_metadata.is_outlier():
- # If this is an outlier, then we know it shouldn't have any current
- # state. Certainly store.get_current_state won't return any, and
- # persisting the event won't store the state group.
-
- # FIXME: why do we populate current_state_ids? I thought the point was
- # that we weren't supposed to have any state for outliers?
- if old_state:
- prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
- if event.is_state():
- current_state_ids = dict(prev_state_ids)
- key = (event.type, event.state_key)
- current_state_ids[key] = event.event_id
- else:
- current_state_ids = prev_state_ids
- else:
- current_state_ids = {}
- prev_state_ids = {}
-
- # We don't store state for outliers, so we don't generate a state
- # group for it.
- context = EventContext.with_state(
- state_group=None,
- state_group_before_event=None,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- )
-
- return context
+ assert not event.internal_metadata.is_outlier()
#
# first of all, figure out the state before the event
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index d0cf3460da..70ca3e09f7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -324,7 +324,7 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return ({}, {})
+ return {}, {}
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e2d1b758bd..2da2659f41 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -60,7 +60,7 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
- hs.hostname, hs.config.app_service_config_files
+ hs.hostname, hs.config.appservice.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 2712514145..dafba2b03f 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -555,8 +555,11 @@ class ClientIpStore(ClientIpWorkerStore):
return ret
async def get_user_ip_and_agents(
- self, user: UserID
+ self, user: UserID, since_ts: int = 0
) -> List[Dict[str, Union[str, int]]]:
+ """
+ Fetch IP/User Agent connection since a given timestamp.
+ """
user_id = user.to_string()
results = {}
@@ -568,13 +571,23 @@ class ClientIpStore(ClientIpWorkerStore):
) = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
- results[(access_token, ip)] = (user_agent, last_seen)
+ if last_seen >= since_ts:
+ results[(access_token, ip)] = (user_agent, last_seen)
- rows = await self.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "last_seen"],
- desc="get_user_ip_and_agents",
+ def get_recent(txn):
+ txn.execute(
+ """
+ SELECT access_token, ip, user_agent, last_seen FROM user_ips
+ WHERE last_seen >= ? AND user_id = ?
+ ORDER BY last_seen
+ DESC
+ """,
+ (since_ts, user_id),
+ )
+ return txn.fetchall()
+
+ rows = await self.db_pool.runInteraction(
+ desc="get_user_ip_and_agents", func=get_recent
)
results.update(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index c55508867d..3154906d45 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -136,7 +136,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id, last_stream_id
)
if not has_changed:
- return ([], current_stream_id)
+ return [], current_stream_id
def get_new_messages_for_device_txn(txn):
sql = (
@@ -240,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
- return ([], current_stream_id)
+ return [], current_stream_id
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
- return ([], last_stream_id)
+ return [], last_stream_id
@trace
def get_new_messages_for_remote_destination_txn(txn):
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1f0a39eac4..a95ac34f09 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -824,6 +824,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
if otk_row is None:
return None
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d72e716b5c..4a1a2f4a6a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1495,7 +1495,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+ return int(res["topological_ordering"]), int(res["stream_ordering"])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index d213b26703..b76ee51a9b 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -63,7 +63,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
- `config.track_appservice_user_ips` must be set to `true` for this
+ `config.appservice.track_appservice_user_ips` must be set to `true` for this
method to return anything other than native matrix users.
Returns:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bccff5e5b9..3eb30944bf 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -102,15 +102,19 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
(room_id,),
)
rows = txn.fetchall()
- max_depth = max(row[1] for row in rows)
-
- if max_depth < token.topological:
- # We need to ensure we don't delete all the events from the database
- # otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremities)
- raise SynapseError(
- 400, "topological_ordering is greater than forward extremeties"
- )
+ # if we already have no forwards extremities (for example because they were
+ # cleared out by the `delete_old_current_state_events` background database
+ # update), then we may as well carry on.
+ if rows:
+ max_depth = max(row[1] for row in rows)
+
+ if max_depth < token.topological:
+ # We need to ensure we don't delete all the events from the database
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backwards extremities)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremities"
+ )
logger.info("[purge] looking for events to delete")
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fafadb88fc..c83089ee63 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -388,7 +388,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"get_users_expiring_soon",
select_users_txn,
self._clock.time_msec(),
- self.config.account_validity_renew_at,
+ self.config.account_validity.account_validity_renew_at,
)
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -2015,7 +2015,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
(user_id_obj.localpart, create_profile_with_displayname),
)
- if self.hs.config.stats_enabled:
+ if self.hs.config.stats.stats_enabled:
# we create a new completed user statistics row
# we don't strictly need current_token since this user really can't
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a4ec6bc328..ddb162a4fc 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -82,7 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if (
self.hs.config.worker.run_background_tasks
- and self.hs.config.metrics_flags.known_servers
+ and self.hs.config.metrics.metrics_flags.known_servers
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1c642c753b..9eb74a81a0 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,12 +15,12 @@
import logging
import re
from collections import namedtuple
-from typing import Collection, List, Optional, Set
+from typing import Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -32,14 +32,24 @@ SearchEntry = namedtuple(
)
+def _clean_value_for_search(value: str) -> str:
+ """
+ Replaces any null code points in the string with spaces as
+ Postgres and SQLite do not like the insertion of strings with
+ null code points into the full-text search tables.
+ """
+ return value.replace("\u0000", " ")
+
+
class SearchWorkerStore(SQLBaseStore):
- def store_search_entries_txn(self, txn, entries):
+ def store_search_entries_txn(
+ self, txn: LoggingTransaction, entries: Iterable[SearchEntry]
+ ) -> None:
"""Add entries to the search table
Args:
- txn (cursor):
- entries (iterable[SearchEntry]):
- entries to be added to the table
+ txn:
+ entries: entries to be added to the table
"""
if not self.hs.config.enable_search:
return
@@ -55,7 +65,7 @@ class SearchWorkerStore(SQLBaseStore):
entry.event_id,
entry.room_id,
entry.key,
- entry.value,
+ _clean_value_for_search(entry.value),
entry.stream_ordering,
entry.origin_server_ts,
)
@@ -70,11 +80,16 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,?)"
)
args = (
- (entry.event_id, entry.room_id, entry.key, entry.value)
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ _clean_value_for_search(entry.value),
+ )
for entry in entries
)
-
txn.execute_batch(sql, args)
+
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -646,6 +661,7 @@ class SearchStore(SearchBackgroundUpdateStore):
for key in ("body", "name", "topic"):
v = event.content.get(key, None)
if v:
+ v = _clean_value_for_search(v)
values.append(v)
if not values:
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index bff7d0404f..a89747d741 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -58,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return (max_stream_id, [])
+ return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 343d6efc92..e20033bb28 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -98,7 +98,7 @@ class StatsStore(StateDeltasStore):
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
- self.stats_enabled = hs.config.stats_enabled
+ self.stats_enabled = hs.config.stats.stats_enabled
self.stats_delta_processing_lock = DeferredLock()
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 959f13de47..dc7884b1c0 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,6 +39,8 @@ import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from frozendict import frozendict
+
from twisted.internet import defer
from synapse.api.filtering import Filter
@@ -379,7 +381,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, positions)
+ return RoomStreamToken(None, min_pos, frozendict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -622,7 +624,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._set_before_and_after(events, rows)
- return (events, token)
+ return events, token
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: RoomStreamToken
@@ -1240,7 +1242,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._set_before_and_after(events, rows)
- return (events, token)
+ return events, token
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 718f3e9976..90d65edc42 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,14 +14,28 @@
import logging
import re
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ cast,
+)
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
- async def _populate_user_directory_createtables(self, progress, batch_size):
+ async def _populate_user_directory_createtables(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# Get all the rooms that we want to process.
- def _make_staging_area(txn):
+ def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
@@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
- async def _populate_user_directory_cleanup(self, progress, batch_size):
+ async def _populate_user_directory_cleanup(
+ self,
+ progress: JsonDict,
+ batch_size: int,
+ ) -> int:
"""
Update the user directory stream position, then clean up the old tables.
"""
position = await self.db_pool.simple_select_one_onecol(
- TEMP_TABLE + "_position", None, "position"
+ TEMP_TABLE + "_position", {}, "position"
)
await self.update_user_directory_stream_pos(position)
- def _delete_staging_area(txn):
+ def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
- async def _populate_user_directory_process_rooms(self, progress, batch_size):
+ async def _populate_user_directory_process_rooms(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
+ Rescan the state of all rooms so we can track
+
+ - who's in a public room;
+ - which local users share a private room with other users (local
+ and remote); and
+ - who should be in the user_directory.
+
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
+
+ Returns:
+ number of events processed.
"""
# If we don't have progress filed, delete everything.
if not progress:
await self.delete_all_from_user_dir()
- def _get_next_batch(txn):
+ def _get_next_batch(
+ txn: LoggingTransaction,
+ ) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
@@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
TEMP_TABLE + "_rooms",
)
txn.execute(sql)
- rooms_to_work_on = txn.fetchall()
+ rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
if not rooms_to_work_on:
return None
@@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
- progress["remaining"] = txn.fetchone()[0]
+ result = txn.fetchone()
+ assert result is not None
+ progress["remaining"] = result[0]
return rooms_to_work_on
@@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
- async def _populate_user_directory_process_users(self, progress, batch_size):
+ async def _populate_user_directory_process_users(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
Add all local users to the user directory.
"""
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users",
str(batch_size),
)
txn.execute(sql)
- users_to_work_on = txn.fetchall()
+ user_result = cast(List[Tuple[str]], txn.fetchall())
- if not users_to_work_on:
+ if not user_result:
return None
- users_to_work_on = [x[0] for x in users_to_work_on]
+ users_to_work_on = [x[0] for x in user_result]
# Get how many are left to process, so we can give status on how
# far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql)
- progress["remaining"] = txn.fetchone()[0]
+ count_result = txn.fetchone()
+ assert count_result is not None
+ progress["remaining"] = count_result[0]
return users_to_work_on
@@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
- async def is_room_world_readable_or_publicly_joinable(self, room_id):
+ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event
@@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if not isinstance(avatar_url, str):
avatar_url = None
- def _update_profile_in_user_dir_txn(txn):
+ def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
@@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id, other_user_id in user_id_tuples
],
value_names=(),
- value_values=None,
+ value_values=(),
desc="add_users_who_share_room",
)
@@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
- value_values=None,
+ value_values=(),
desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory"""
- def _delete_all_from_user_dir_txn(txn):
+ def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
@@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
@@ -497,16 +542,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ) -> None:
super().__init__(database, db_conn, hs)
self._prefer_local_users_in_search = (
- hs.config.user_directory_search_prefer_local_users
+ hs.config.userdirectory.user_directory_search_prefer_local_users
)
self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None:
- def _remove_from_user_dir_txn(txn):
+ def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
@@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn
)
- async def get_users_in_dir_due_to_room(self, room_id):
+ async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
@@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
room_id
"""
- def _remove_user_who_share_room_txn(txn):
+ def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
@@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- async def get_user_dir_rooms_user_is_in(self, user_id):
+ async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
"""
Returns the rooms that a user is in.
@@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
- def _get_shared_rooms_for_users_txn(txn):
+ def _get_shared_rooms_for_users_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, str]]:
txn.execute(
"""
SELECT p1.room_id
@@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
- async def search_user_dir(self, user_id, search_term, limit):
+ async def search_user_dir(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -687,7 +741,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
}
"""
- if self.hs.config.user_directory_search_all_users:
+ if self.hs.config.userdirectory.user_directory_search_all_users:
join_args = (user_id,)
where_clause = "user_id != ?"
else:
@@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
- ordering_arguments = ()
+ ordering_arguments: Tuple[str, ...] = ()
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results}
-def _parse_query_sqlite(search_term):
+def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
@@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
-def _parse_query_postgres(search_term):
+def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d4754c904c..f31880b8ec 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -545,7 +545,7 @@ def _apply_module_schemas(
database_engine:
config: application config
"""
- for (mod, _config) in config.password_providers:
+ for (mod, _config) in config.authproviders.password_providers:
if not hasattr(mod, "get_db_schema_files"):
continue
modname = ".".join((mod.__module__, mod.__name__))
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index aa2ce44c6c..573e05a482 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -27,11 +27,22 @@ for more information on how this works.
Changes in SCHEMA_VERSION = 61:
- The `user_stats_historical` and `room_stats_historical` tables are not written and
are not read (previously, they were written but not read).
+ - MSC2716: Add `insertion_events` and `insertion_event_edges` tables to keep track
+ of insertion events in order to navigate historical chunks of messages.
+ - MSC2716: Add `chunk_events` table to track how the chunk is labeled and
+ determines which insertion event it points to.
+
+Changes in SCHEMA_VERSION = 62:
+ - MSC2716: Add `insertion_event_extremities` table that keeps track of which
+ insertion events need to be backfilled.
Changes in SCHEMA_VERSION = 63:
- The `public_room_list_stream` table is not written nor read to
(previously, it was written and read to, but not for any significant purpose).
https://github.com/matrix-org/synapse/pull/10565
+
+Changes in SCHEMA_VERSION = 64:
+ - MSC2716: Rename related tables and columns from "chunks" to "batches".
"""
diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py
index 8a1f340083..22a7901e15 100644
--- a/synapse/storage/schema/main/delta/30/as_users.py
+++ b/synapse/storage/schema/main/delta/30/as_users.py
@@ -33,7 +33,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
config_files = []
try:
- config_files = config.app_service_config_files
+ config_files = config.appservice.app_service_config_files
except AttributeError:
logger.warning("Could not get app_service_config_files from config")
pass
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index cf4005984b..c08d591f29 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -81,7 +81,7 @@ class PaginationConfig:
raise SynapseError(400, "Invalid request.")
def __repr__(self) -> str:
- return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
+ return "PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)" % (
self.from_token,
self.to_token,
self.direction,
diff --git a/synapse/types.py b/synapse/types.py
index 90168ce8fa..364ecf7d45 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -30,6 +30,7 @@ from typing import (
)
import attr
+from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -457,6 +458,9 @@ class RoomStreamToken:
Note: The `RoomStreamToken` cannot have both a topological part and an
instance map.
+
+ For caching purposes, `RoomStreamToken`s and by extension, all their
+ attributes, must be hashable.
"""
topological = attr.ib(
@@ -466,12 +470,12 @@ class RoomStreamToken:
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
instance_map = attr.ib(
- type=Dict[str, int],
- factory=dict,
+ type="frozendict[str, int]",
+ factory=frozendict,
validator=attr.validators.deep_mapping(
key_validator=attr.validators.instance_of(str),
value_validator=attr.validators.instance_of(int),
- mapping_validator=attr.validators.instance_of(dict),
+ mapping_validator=attr.validators.instance_of(frozendict),
),
)
@@ -507,7 +511,7 @@ class RoomStreamToken:
return cls(
topological=None,
stream=stream,
- instance_map=instance_map,
+ instance_map=frozendict(instance_map),
)
except Exception:
pass
@@ -540,7 +544,7 @@ class RoomStreamToken:
for instance in set(self.instance_map).union(other.instance_map)
}
- return RoomStreamToken(None, max_stream, instance_map)
+ return RoomStreamToken(None, max_stream, frozendict(instance_map))
def as_historical_tuple(self) -> Tuple[int, int]:
"""Returns a tuple of `(topological, stream)` for historical tokens.
@@ -552,7 +556,7 @@ class RoomStreamToken:
"Cannot call `RoomStreamToken.as_historical_tuple` on live token"
)
- return (self.topological, self.stream)
+ return self.topological, self.stream
def get_stream_pos_for_instance(self, instance_name: str) -> int:
"""Get the stream position that the given writer was at at this token.
@@ -593,6 +597,12 @@ class RoomStreamToken:
@attr.s(slots=True, frozen=True)
class StreamToken:
+ """A collection of positions within multiple streams.
+
+ For caching purposes, `StreamToken`s and by extension, all their attributes,
+ must be hashable.
+ """
+
room_key = attr.ib(
type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
)
@@ -756,7 +766,7 @@ def get_verify_key_from_cross_signing_key(key_info):
raise ValueError("Invalid key")
# and return that one key
for key_id, key_data in keys.items():
- return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
+ return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
@attr.s(auto_attribs=True, frozen=True, slots=True)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 8ac3eab2f5..4938ddf703 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -21,13 +21,28 @@ from typing import (
Iterable,
Iterator,
Mapping,
- Sequence,
Set,
+ Sized,
Tuple,
TypeVar,
)
+from typing_extensions import Protocol
+
T = TypeVar("T")
+S = TypeVar("S", bound="_SelfSlice")
+
+
+class _SelfSlice(Sized, Protocol):
+ """A helper protocol that matches types where taking a slice results in the
+ same type being returned.
+
+ This is more specific than `Sequence`, which allows another `Sequence` to be
+ returned.
+ """
+
+ def __getitem__(self: S, i: slice) -> S:
+ ...
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
@@ -46,7 +61,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
return iter(lambda: tuple(islice(sourceiter, size)), ())
-def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
+def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
"""Split the given sequence into chunks of the given size
The last chunk may be shorter than the given size.
|