diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index dedff81af3..fb476ddaf5 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import logging
import os
@@ -20,8 +19,8 @@ import signal
import socket
import sys
import traceback
+from typing import Iterable
-from daemonize import Daemonize
from typing_extensions import NoReturn
from twisted.internet import defer, error, reactor
@@ -29,9 +28,11 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
+from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
+from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -127,17 +128,8 @@ def start_reactor(
if print_pidfile:
print(pid_file)
- daemon = Daemonize(
- app=appname,
- pid=pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ daemonize_process(pid_file, logger)
+ run()
def quit_with_error(error_string: str) -> NoReturn:
@@ -234,7 +226,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-def start(hs, listeners=None):
+def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
"""
Start a Synapse server or worker.
@@ -245,8 +237,8 @@ def start(hs, listeners=None):
notify systemd.
Args:
- hs (synapse.server.HomeServer)
- listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
+ hs: homeserver instance
+ listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
try:
# Set up the SIGHUP machinery.
@@ -276,7 +268,7 @@ def start(hs, listeners=None):
# It is now safe to start your Synapse.
hs.start_listening(listeners)
- hs.get_datastore().db.start_profiling()
+ hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
setup_sentry(hs)
@@ -342,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
can run out of file descriptors and infinite loop if we attempt to do too
many DNS queries at once
+
+ XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
+ you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
+ backed by the reactor's default threadpool (which is limited to 10 threads). So
+ (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
+ understand why we would run out of FDs if we did too many lookups at once.
+ -- richvdh 2020/08/29
"""
new_resolver = _LimitedHostnameResolver(
reactor.nameResolver, max_dns_requests_in_flight
@@ -350,7 +349,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
reactor.installNameResolver(new_resolver)
-class _LimitedHostnameResolver(object):
+class _LimitedHostnameResolver:
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
"""
@@ -410,7 +409,7 @@ class _LimitedHostnameResolver(object):
yield deferred
-class _DeferredResolutionReceiver(object):
+class _DeferredResolutionReceiver:
"""Wraps a IResolutionReceiver and simply resolves the given deferred when
resolution is complete
"""
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index a37818fe9a..b6c9085670 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer):
pass
-@defer.inlineCallbacks
-def export_data_command(hs, args):
+async def export_data_command(hs, args):
"""Export data for a user.
Args:
@@ -91,10 +90,8 @@ def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
- res = yield defer.ensureDeferred(
- hs.get_handlers().admin_handler.export_user_data(
- user_id, FileExfiltrationWriter(user_id, directory=directory)
- )
+ res = await hs.get_handlers().admin_handler.export_user_data(
+ user_id, FileExfiltrationWriter(user_id, directory=directory)
)
print(res)
@@ -232,14 +229,15 @@ def start(config_options):
# We also make sure that `_base.start` gets run before we actually run the
# command.
- @defer.inlineCallbacks
- def run(_reactor):
+ async def run():
with LoggingContext("command"):
- yield _base.start(ss, [])
- yield args.func(ss, args)
+ _base.start(ss, [])
+ await args.func(ss, args)
_base.start_worker_reactor(
- "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+ "synapse-admin-cmd",
+ config,
+ run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f3ec2a34ec..f985810e88 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
-from twisted.internet import defer, reactor
+from twisted.internet import address, reactor
import synapse
import synapse.events
@@ -37,6 +37,7 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.config.server import ListenerConfig
from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.presence import (
@@ -86,7 +87,6 @@ from synapse.replication.tcp.streams import (
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
- TypingStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
@@ -110,6 +110,7 @@ from synapse.rest.client.v1.room import (
RoomSendEventRestServlet,
RoomStateEventRestServlet,
RoomStateRestServlet,
+ RoomTypingRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
@@ -122,17 +123,18 @@ from synapse.rest.client.v2_alpha.account_data import (
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
+from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.censor_events import CensorEventsStore
-from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
-from synapse.storage.data_stores.main.monthly_active_users import (
+from synapse.server import HomeServer, cache_in_self
+from synapse.storage.databases.main.censor_events import CensorEventsStore
+from synapse.storage.databases.main.media_repository import MediaRepositoryStore
+from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
-from synapse.storage.data_stores.main.presence import UserPresenceState
-from synapse.storage.data_stores.main.search import SearchWorkerStore
-from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.databases.main.presence import UserPresenceState
+from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
+from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
@@ -205,10 +207,30 @@ class KeyUploadServlet(RestServlet):
if body:
# They're actually trying to upload something, proxy to main synapse.
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
- headers = {"Authorization": auth_headers}
+
+ # Proxy headers from the original request, such as the auth headers
+ # (in case the access token is there) and the original IP /
+ # User-Agent of the request.
+ headers = {
+ header: request.requestHeaders.getRawHeaders(header, [])
+ for header in (b"Authorization", b"User-Agent")
+ }
+ # Add the previous hop the the X-Forwarded-For header.
+ x_forwarded_for = request.requestHeaders.getRawHeaders(
+ b"X-Forwarded-For", []
+ )
+ if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
+ previous_host = request.client.host.encode("ascii")
+ # If the header exists, add to the comma-separated list of the first
+ # instance of the header. Otherwise, generate a new header.
+ if x_forwarded_for:
+ x_forwarded_for = [
+ x_forwarded_for[0] + b", " + previous_host
+ ] + x_forwarded_for[1:]
+ else:
+ x_forwarded_for = [previous_host]
+ headers[b"X-Forwarded-For"] = x_forwarded_for
+
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
@@ -353,9 +375,8 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing()
- @defer.inlineCallbacks
- def notify_from_replication(self, states, stream_id):
- parties = yield get_interested_parties(self.store, states)
+ async def notify_from_replication(self, states, stream_id):
+ parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -365,8 +386,7 @@ class GenericWorkerPresence(BasePresenceHandler):
users=users_to_states.keys(),
)
- @defer.inlineCallbacks
- def process_replication_rows(self, token, rows):
+ async def process_replication_rows(self, token, rows):
states = [
UserPresenceState(
row.user_id,
@@ -384,7 +404,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.user_to_current_state[state.user_id] = state
stream_id = token
- yield self.notify_from_replication(states, stream_id)
+ await self.notify_from_replication(states, stream_id)
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
return [
@@ -430,37 +450,6 @@ class GenericWorkerPresence(BasePresenceHandler):
await self._bump_active_client(user_id=user_id)
-class GenericWorkerTyping(object):
- def __init__(self, hs):
- self._latest_room_serial = 0
- self._reset()
-
- def _reset(self):
- """
- Reset the typing handler's data caches.
- """
- # map room IDs to serial numbers
- self._room_serials = {}
- # map room IDs to sets of users currently typing
- self._room_typing = {}
-
- def process_replication_rows(self, token, rows):
- if self._latest_room_serial > token:
- # The master has gone backwards. To prevent inconsistent data, just
- # clear everything.
- self._reset()
-
- # Set the latest serial token to whatever the server gave us.
- self._latest_room_serial = token
-
- for row in rows:
- self._room_serials[row.room_id] = token
- self._room_typing[row.room_id] = row.user_ids
-
- def get_current_token(self) -> int:
- return self._latest_room_serial
-
-
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
@@ -490,37 +479,27 @@ class GenericWorkerSlavedStore(
SearchWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database, db_conn, hs):
- super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs)
+ pass
- # We pull out the current federation stream position now so that we
- # always have a known value for the federation position in memory so
- # that we don't have to bounce via a deferred once when we start the
- # replication streams.
- self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
- def _get_federation_out_pos(self, db_conn):
- sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
- sql = self.database_engine.convert_param_style(sql)
+class GenericWorkerServer(HomeServer):
+ DATASTORE_CLASS = GenericWorkerSlavedStore
- txn = db_conn.cursor()
- txn.execute(sql, ("federation",))
- rows = txn.fetchall()
- txn.close()
+ def _listen_http(self, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
- return rows[0][0] if rows else -1
+ assert listener_config.http_options is not None
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
-class GenericWorkerServer(HomeServer):
- DATASTORE_CLASS = GenericWorkerSlavedStore
+ # We always include a health resource.
+ resources = {"/health": HealthResource()}
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
+ for res in listener_config.http_options.resources:
+ for name in res.names:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
@@ -550,6 +529,7 @@ class GenericWorkerServer(HomeServer):
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
+ RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
@@ -590,7 +570,7 @@ class GenericWorkerServer(HomeServer):
" repository is disabled. Ignoring."
)
- if name == "openid" and "federation" not in res["names"]:
+ if name == "openid" and "federation" not in res.names:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
@@ -625,19 +605,19 @@ class GenericWorkerServer(HomeServer):
logger.info("Synapse worker now listening on port %d", port)
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listen_http(listener)
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
_base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -646,31 +626,29 @@ class GenericWorkerServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unsupported listener type: %s", listener.type)
self.get_tcp_replication().start_replication(self)
- def remove_pusher(self, app_id, push_key, user_id):
+ async def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
- def build_replication_data_handler(self):
+ @cache_in_self
+ def get_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
- def build_presence_handler(self):
+ @cache_in_self
+ def get_presence_handler(self):
return GenericWorkerPresence(self)
- def build_typing_handler(self):
- return GenericWorkerTyping(self)
-
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs)
self.store = hs.get_datastore()
- self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
self.notifier = hs.get_notifier()
@@ -707,11 +685,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
- elif stream_name == TypingStream.NAME:
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows]
- )
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
@@ -738,6 +711,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
except Exception:
logger.exception("Error processing replication")
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ await super().on_position(stream_name, instance_name, token)
+ # Also call on_rdata to ensure that stream positions are properly reset.
+ await self.on_rdata(stream_name, instance_name, token, [])
+
def stop_pusher(self, user_id, app_id, pushkey):
if not self.notify_pushers:
return
@@ -767,7 +745,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
self.send_handler.wake_destination(server)
-class FederationSenderHandler(object):
+class FederationSenderHandler:
"""Processes the fedration replication stream
This class is only instantiate on the worker responsible for sending outbound
@@ -781,19 +759,11 @@ class FederationSenderHandler(object):
self.federation_sender = hs.get_federation_sender()
self._hs = hs
- # if the worker is restarted, we want to pick up where we left off in
- # the replication stream, so load the position from the database.
- #
- # XXX is this actually worthwhile? Whenever the master is restarted, we'll
- # drop some rows anyway (which is mostly fine because we're only dropping
- # typing and presence notifications). If the replication stream is
- # unreliable, why do we do all this hoop-jumping to store the position in the
- # database? See also https://github.com/matrix-org/synapse/issues/7535.
- #
- self.federation_position = self.store.federation_out_pos_startup
+ # Stores the latest position in the federation stream we've gotten up
+ # to. This is always set before we use it.
+ self.federation_position = None
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- self._last_ack = self.federation_position
def on_start(self):
# There may be some events that are persisted but haven't been sent,
@@ -901,7 +871,6 @@ class FederationSenderHandler(object):
# We ACK this token over replication so that the master can drop
# its in memory queues
self._hs.get_tcp_replication().send_federation_ack(current_position)
- self._last_ack = current_position
except Exception:
logger.exception("Error updating federation stream position")
@@ -929,7 +898,7 @@ def start(config_options):
)
if config.worker_app == "synapse.app.appservice":
- if config.notify_appservices:
+ if config.appservice.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -939,13 +908,13 @@ def start(config_options):
sys.exit(1)
# Force the appservice to start since they will be disabled in the main config
- config.notify_appservices = True
+ config.appservice.notify_appservices = True
else:
# For other worker types we force this to off.
- config.notify_appservices = False
+ config.appservice.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
- if config.start_pushers:
+ if config.server.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -955,13 +924,13 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.start_pushers = True
+ config.server.start_pushers = True
else:
# For other worker types we force this to off.
- config.start_pushers = False
+ config.server.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
- if config.update_user_directory:
+ if config.server.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -971,13 +940,13 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.update_user_directory = True
+ config.server.update_user_directory = True
else:
# For other worker types we force this to off.
- config.update_user_directory = False
+ config.server.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
- if config.send_federation:
+ if config.worker.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -987,10 +956,10 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.send_federation = True
+ config.worker.send_federation = True
else:
# For other worker types we force this to off.
- config.send_federation = False
+ config.worker.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8454d74858..6014adc850 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -23,8 +23,7 @@ import math
import os
import resource
import sys
-
-from six import iteritems
+from typing import Iterable
from prometheus_client import Gauge
@@ -50,12 +49,14 @@ from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import (
OptionsResource,
RootOptionsRedirectResource,
RootRedirect,
+ StaticResource,
)
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
@@ -67,6 +68,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
+from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
@@ -89,24 +91,26 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
- def _listener_http(self, config, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- tls = listener_config.get("tls", False)
- site_tag = listener_config.get("tag", port)
+ def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
+ tls = listener_config.tls
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "openid" and "federation" in res["names"]:
+ # We always include a health resource.
+ resources = {"/health": HealthResource()}
+
+ for res in listener_config.http_options.resources:
+ for name in res.names:
+ if name == "openid" and "federation" in res.names:
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
- resources.update(
- self._configure_named_resource(name, res.get("compress", False))
- )
+ resources.update(self._configure_named_resource(name, res.compress))
- additional_resources = listener_config.get("additional_resources", {})
+ additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
@@ -228,7 +232,7 @@ class SynapseHomeServer(HomeServer):
if name in ["static", "client"]:
resources.update(
{
- STATIC_PREFIX: File(
+ STATIC_PREFIX: StaticResource(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
}
@@ -278,7 +282,7 @@ class SynapseHomeServer(HomeServer):
return resources
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
config = self.get_config()
if config.redis_enabled:
@@ -288,25 +292,25 @@ class SynapseHomeServer(HomeServer):
self.get_tcp_replication().start_replication(self)
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listening_services.extend(self._listener_http(config, listener))
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "replication":
+ elif listener.type == "replication":
services = listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -315,9 +319,11 @@ class SynapseHomeServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ # this shouldn't happen, as the listener type should have been checked
+ # during parsing
+ logger.warning("Unrecognized listener type: %s", listener.type)
# Gauges to expose monthly active user control metrics
@@ -377,13 +383,12 @@ def setup(config_options):
hs.setup_master()
- @defer.inlineCallbacks
- def do_acme():
+ async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
Returns:
- Deferred[bool]: Whether the cert has been updated.
+ Whether the cert has been updated.
"""
acme = hs.get_acme_handler()
@@ -402,30 +407,28 @@ def setup(config_options):
provision = True
if provision:
- yield acme.provision_certificate()
+ await acme.provision_certificate()
return provision
- @defer.inlineCallbacks
- def reprovision_acme():
+ async def reprovision_acme():
"""
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
- reprovisioned = yield do_acme()
+ reprovisioned = await do_acme()
if reprovisioned:
_base.refresh_certificate(hs)
- @defer.inlineCallbacks
- def start():
+ async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
- yield acme.start_listening()
- yield do_acme()
+ await acme.start_listening()
+ await do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -434,12 +437,12 @@ def setup(config_options):
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
- yield defer.ensureDeferred(oidc.load_metadata())
- yield defer.ensureDeferred(oidc.load_jwks())
+ await oidc.load_metadata()
+ await oidc.load_jwks()
_base.start(hs, config.listeners)
- hs.get_datastore().db.updates.start_doing_background_updates()
+ hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
@@ -451,7 +454,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
return hs
@@ -480,8 +483,7 @@ class SynapseService(service.Service):
_stats_process = []
-@defer.inlineCallbacks
-def phone_stats_home(hs, stats, stats_process=_stats_process):
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
@@ -519,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
- stats["total_users"] = yield hs.get_datastore().count_all_users()
+ stats["total_users"] = await hs.get_datastore().count_all_users()
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+ for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
- room_count = yield hs.get_datastore().get_room_count()
+ room_count = await hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
- r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
+ r30_results = await hs.get_datastore().count_r30_users()
+ for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@@ -550,12 +552,12 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
#
# This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
- yield hs.get_proxied_http_client().put_json(
+ await hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 1b13e84425..13ec1f71a6 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,24 +14,25 @@
# limitations under the License.
import logging
import re
-
-from six import string_types
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING
from synapse.api.constants import EventTypes
+from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
+
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-class ApplicationServiceState(object):
+class ApplicationServiceState:
DOWN = "down"
UP = "up"
-class AppServiceTransaction(object):
+class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(self, service, id, events):
@@ -39,19 +40,19 @@ class AppServiceTransaction(object):
self.id = id
self.events = events
- def send(self, as_api):
+ async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
- as_api(ApplicationServiceApi): The API to use to send.
+ as_api: The API to use to send.
Returns:
- A Deferred which resolves to True if the transaction was sent.
+ True if the transaction was sent.
"""
- return as_api.push_bulk(
+ return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id
)
- def complete(self, store):
+ async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
@@ -59,13 +60,11 @@ class AppServiceTransaction(object):
Args:
store: The database store to operate on.
- Returns:
- A Deferred which resolves to True if the transaction was completed.
"""
- return store.complete_appservice_txn(service=self.service, txn_id=self.id)
+ await store.complete_appservice_txn(service=self.service, txn_id=self.id)
-class ApplicationService(object):
+class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -156,7 +155,7 @@ class ApplicationService(object):
)
regex = regex_obj.get("regex")
- if isinstance(regex, string_types):
+ if isinstance(regex, str):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
@@ -174,8 +173,7 @@ class ApplicationService(object):
return regex_obj["exclusive"]
return False
- @defer.inlineCallbacks
- def _matches_user(self, event, store):
+ async def _matches_user(self, event, store):
if not event:
return False
@@ -190,12 +188,12 @@ class ApplicationService(object):
if not store:
return False
- does_match = yield self._matches_user_in_member_list(event.room_id, store)
+ does_match = await self._matches_user_in_member_list(event.room_id, store)
return does_match
- @cachedInlineCallbacks(num_args=1, cache_context=True)
- def _matches_user_in_member_list(self, room_id, store, cache_context):
- member_list = yield store.get_users_in_room(
+ @cached(num_args=1, cache_context=True)
+ async def _matches_user_in_member_list(self, room_id, store, cache_context):
+ member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
@@ -210,35 +208,33 @@ class ApplicationService(object):
return self.is_interested_in_room(event.room_id)
return False
- @defer.inlineCallbacks
- def _matches_aliases(self, event, store):
+ async def _matches_aliases(self, event, store):
if not store or not event:
return False
- alias_list = yield store.get_aliases_for_room(event.room_id)
+ alias_list = await store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
return True
return False
- @defer.inlineCallbacks
- def is_interested(self, event, store=None):
+ async def is_interested(self, event, store=None) -> bool:
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
store(DataStore)
Returns:
- bool: True if this service would like to know about this event.
+ True if this service would like to know about this event.
"""
# Do cheap checks first
if self._matches_room_id(event):
return True
- if (yield self._matches_aliases(event, store)):
+ if await self._matches_aliases(event, store):
return True
- if (yield self._matches_user(event, store)):
+ if await self._matches_user(event, store):
return True
return False
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 57174da021..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,20 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from six.moves import urllib
+import urllib
+from typing import TYPE_CHECKING, Optional
from prometheus_client import Counter
-from twisted.internet import defer
-
-from synapse.api.constants import ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.appservice import ApplicationService
+
logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
@@ -94,14 +95,12 @@ class ApplicationServiceApi(SimpleHttpClient):
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
- @defer.inlineCallbacks
- def query_user(self, service, user_id):
+ async def query_user(self, service, user_id):
if service.url is None:
return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
- response = None
try:
- response = yield self.get_json(uri, {"access_token": service.hs_token})
+ response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
return True
except CodeMessageException as e:
@@ -112,14 +111,12 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_user to %s threw exception %s", uri, ex)
return False
- @defer.inlineCallbacks
- def query_alias(self, service, alias):
+ async def query_alias(self, service, alias):
if service.url is None:
return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
- response = None
try:
- response = yield self.get_json(uri, {"access_token": service.hs_token})
+ response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
return True
except CodeMessageException as e:
@@ -130,8 +127,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex)
return False
- @defer.inlineCallbacks
- def query_3pe(self, service, kind, protocol, fields):
+ async def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
@@ -148,7 +144,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- response = yield self.get_json(uri, fields)
+ response = await self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response
@@ -169,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex)
return []
- def get_3pe_protocol(self, service, protocol):
+ async def get_3pe_protocol(
+ self, service: "ApplicationService", protocol: str
+ ) -> Optional[JsonDict]:
if service.url is None:
return {}
- @defer.inlineCallbacks
- def _get():
+ async def _get() -> Optional[JsonDict]:
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol),
)
try:
- info = yield self.get_json(uri, {})
+ info = await self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning(
@@ -202,14 +199,13 @@ class ApplicationServiceApi(SimpleHttpClient):
return None
key = (service.id, protocol)
- return self.protocol_meta_cache.wrap(key, _get)
+ return await self.protocol_meta_cache.wrap(key, _get)
- @defer.inlineCallbacks
- def push_bulk(self, service, events, txn_id=None):
+ async def push_bulk(self, service, events, txn_id=None):
if service.url is None:
return True
- events = self._serialize(events)
+ events = self._serialize(service, events)
if txn_id is None:
logger.warning(
@@ -220,7 +216,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
- yield self.put_json(
+ await self.put_json(
uri=uri,
json_body={"events": events},
args={"access_token": service.hs_token},
@@ -235,6 +231,18 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
- def _serialize(self, events):
+ def _serialize(self, service, events):
time_now = self.clock.time_msec()
- return [serialize_event(e, time_now, as_client_event=True) for e in events]
+ return [
+ serialize_event(
+ e,
+ time_now,
+ as_client_event=True,
+ is_invite=(
+ e.type == EventTypes.Member
+ and e.membership == "invite"
+ and service.is_interested_in_user(e.state_key)
+ ),
+ )
+ for e in events
+ ]
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 9998f822f1..8eb8c6f51c 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -50,8 +50,6 @@ components.
"""
import logging
-from twisted.internet import defer
-
from synapse.appservice import ApplicationServiceState
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -59,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
logger = logging.getLogger(__name__)
-class ApplicationServiceScheduler(object):
+class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
@@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
- @defer.inlineCallbacks
- def start(self):
+ async def start(self):
logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them.
- services = yield self.store.get_appservices_by_state(
+ services = await self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
@@ -89,7 +86,7 @@ class ApplicationServiceScheduler(object):
self.queuer.enqueue(service, event)
-class _ServiceQueuer(object):
+class _ServiceQueuer:
"""Queue of events waiting to be sent to appservices.
Groups events into transactions per-appservice, and sends them on to the
@@ -117,8 +114,7 @@ class _ServiceQueuer(object):
"as-sender-%s" % (service.id,), self._send_request, service
)
- @defer.inlineCallbacks
- def _send_request(self, service):
+ async def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@@ -130,14 +126,14 @@ class _ServiceQueuer(object):
if not events:
return
try:
- yield self.txn_ctrl.send(service, events)
+ await self.txn_ctrl.send(service, events)
except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
-class _TransactionController(object):
+class _TransactionController:
"""Transaction manager.
Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer
@@ -162,36 +158,33 @@ class _TransactionController(object):
# for UTs
self.RECOVERER_CLASS = _Recoverer
- @defer.inlineCallbacks
- def send(self, service, events):
+ async def send(self, service, events):
try:
- txn = yield self.store.create_appservice_txn(service=service, events=events)
- service_is_up = yield self._is_service_up(service)
+ txn = await self.store.create_appservice_txn(service=service, events=events)
+ service_is_up = await self._is_service_up(service)
if service_is_up:
- sent = yield txn.send(self.as_api)
+ sent = await txn.send(self.as_api)
if sent:
- yield txn.complete(self.store)
+ await txn.complete(self.store)
else:
run_in_background(self._on_txn_fail, service)
except Exception:
logger.exception("Error creating appservice transaction")
run_in_background(self._on_txn_fail, service)
- @defer.inlineCallbacks
- def on_recovered(self, recoverer):
+ async def on_recovered(self, recoverer):
logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id
)
self.recoverers.pop(recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
- yield self.store.set_appservice_state(
+ await self.store.set_appservice_state(
recoverer.service, ApplicationServiceState.UP
)
- @defer.inlineCallbacks
- def _on_txn_fail(self, service):
+ async def _on_txn_fail(self, service):
try:
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
self.start_recoverer(service)
except Exception:
logger.exception("Error starting AS recoverer")
@@ -211,13 +204,12 @@ class _TransactionController(object):
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
- @defer.inlineCallbacks
- def _is_service_up(self, service):
- state = yield self.store.get_appservice_state(service)
+ async def _is_service_up(self, service):
+ state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None
-class _Recoverer(object):
+class _Recoverer:
"""Manages retries and backoff for a DOWN appservice.
We have one of these for each appservice which is currently considered DOWN.
@@ -254,25 +246,24 @@ class _Recoverer(object):
self.backoff_counter += 1
self.recover()
- @defer.inlineCallbacks
- def retry(self):
+ async def retry(self):
logger.info("Starting retries on %s", self.service.id)
try:
while True:
- txn = yield self.store.get_oldest_unsent_txn(self.service)
+ txn = await self.store.get_oldest_unsent_txn(self.service)
if not txn:
# nothing left: we're done!
- self.callback(self)
+ await self.callback(self)
return
logger.info(
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id
)
- sent = yield txn.send(self.as_api)
+ sent = await txn.send(self.as_api)
if not sent:
break
- yield txn.complete(self.store)
+ await txn.complete(self.store)
# reset the backoff counter and then process the next transaction
self.backoff_counter = 1
|