diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8454d74858..6014adc850 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -23,8 +23,7 @@ import math
import os
import resource
import sys
-
-from six import iteritems
+from typing import Iterable
from prometheus_client import Gauge
@@ -50,12 +49,14 @@ from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import (
OptionsResource,
RootOptionsRedirectResource,
RootRedirect,
+ StaticResource,
)
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
@@ -67,6 +68,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
+from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
@@ -89,24 +91,26 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
- def _listener_http(self, config, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- tls = listener_config.get("tls", False)
- site_tag = listener_config.get("tag", port)
+ def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
+ tls = listener_config.tls
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "openid" and "federation" in res["names"]:
+ # We always include a health resource.
+ resources = {"/health": HealthResource()}
+
+ for res in listener_config.http_options.resources:
+ for name in res.names:
+ if name == "openid" and "federation" in res.names:
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
- resources.update(
- self._configure_named_resource(name, res.get("compress", False))
- )
+ resources.update(self._configure_named_resource(name, res.compress))
- additional_resources = listener_config.get("additional_resources", {})
+ additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
@@ -228,7 +232,7 @@ class SynapseHomeServer(HomeServer):
if name in ["static", "client"]:
resources.update(
{
- STATIC_PREFIX: File(
+ STATIC_PREFIX: StaticResource(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
}
@@ -278,7 +282,7 @@ class SynapseHomeServer(HomeServer):
return resources
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
config = self.get_config()
if config.redis_enabled:
@@ -288,25 +292,25 @@ class SynapseHomeServer(HomeServer):
self.get_tcp_replication().start_replication(self)
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listening_services.extend(self._listener_http(config, listener))
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "replication":
+ elif listener.type == "replication":
services = listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -315,9 +319,11 @@ class SynapseHomeServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ # this shouldn't happen, as the listener type should have been checked
+ # during parsing
+ logger.warning("Unrecognized listener type: %s", listener.type)
# Gauges to expose monthly active user control metrics
@@ -377,13 +383,12 @@ def setup(config_options):
hs.setup_master()
- @defer.inlineCallbacks
- def do_acme():
+ async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
Returns:
- Deferred[bool]: Whether the cert has been updated.
+ Whether the cert has been updated.
"""
acme = hs.get_acme_handler()
@@ -402,30 +407,28 @@ def setup(config_options):
provision = True
if provision:
- yield acme.provision_certificate()
+ await acme.provision_certificate()
return provision
- @defer.inlineCallbacks
- def reprovision_acme():
+ async def reprovision_acme():
"""
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
- reprovisioned = yield do_acme()
+ reprovisioned = await do_acme()
if reprovisioned:
_base.refresh_certificate(hs)
- @defer.inlineCallbacks
- def start():
+ async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
- yield acme.start_listening()
- yield do_acme()
+ await acme.start_listening()
+ await do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -434,12 +437,12 @@ def setup(config_options):
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
- yield defer.ensureDeferred(oidc.load_metadata())
- yield defer.ensureDeferred(oidc.load_jwks())
+ await oidc.load_metadata()
+ await oidc.load_jwks()
_base.start(hs, config.listeners)
- hs.get_datastore().db.updates.start_doing_background_updates()
+ hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
@@ -451,7 +454,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
return hs
@@ -480,8 +483,7 @@ class SynapseService(service.Service):
_stats_process = []
-@defer.inlineCallbacks
-def phone_stats_home(hs, stats, stats_process=_stats_process):
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
@@ -519,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
- stats["total_users"] = yield hs.get_datastore().count_all_users()
+ stats["total_users"] = await hs.get_datastore().count_all_users()
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+ for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
- room_count = yield hs.get_datastore().get_room_count()
+ room_count = await hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
- r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
+ r30_results = await hs.get_datastore().count_r30_users()
+ for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@@ -550,12 +552,12 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
#
# This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
- yield hs.get_proxied_http_client().put_json(
+ await hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
|