diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index fb476ddaf5..895b38ae76 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -28,9 +28,11 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
+from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
@@ -48,7 +50,6 @@ def register_sighup(func, *args, **kwargs):
Args:
func (function): Function to be called when sent a SIGHUP signal.
- Will be called with a single default argument, the homeserver.
*args, **kwargs: args and kwargs to be passed to the target function.
"""
_sighup_callbacks.append((func, args, kwargs))
@@ -244,19 +245,26 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+ @wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
for i, args, kwargs in _sighup_callbacks:
- i(hs, *args, **kwargs)
+ i(*args, **kwargs)
sdnotify(b"READY=1")
- signal.signal(signal.SIGHUP, handle_sighup)
+ # We defer running the sighup handlers until next reactor tick. This
+ # is so that we're in a sane state, e.g. flushing the logs may fail
+ # if the sighup happens in the middle of writing a log entry.
+ def run_sighup(*args, **kwargs):
+ hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
- register_sighup(refresh_certificate)
+ signal.signal(signal.SIGHUP, run_sighup)
+
+ register_sighup(refresh_certificate, hs)
# Load the certificate from disk.
refresh_certificate(hs)
@@ -271,9 +279,19 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
+ # Log when we start the shutdown process.
+ hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", logger.info, "Shutting down..."
+ )
+
setup_sentry(hs)
setup_sdnotify(hs)
+ # If background tasks are running on the main process, start collecting the
+ # phone home stats.
+ if hs.config.run_background_tasks:
+ start_phone_stats_home(hs)
+
# We now freeze all allocated objects in the hopes that (almost)
# everything currently allocated are things that will be used for the
# rest of time. Doing so means less work each GC (hopefully).
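
The SIGHUP changes above follow a standard pattern: a Unix signal handler can fire between any two bytecodes, so it should do nothing except schedule the real work onto the event loop for the next tick. A minimal standalone sketch of that pattern, using Twisted's reactor directly instead of Synapse's clock wrapper (the helper name below is invented):

    import signal

    from twisted.internet import reactor

    def install_sighup(handler, *args, **kwargs):
        """Install `handler` for SIGHUP, deferred to the next reactor tick."""

        def on_sighup(signum, frame):
            # A zero-delay callLater mirrors hs.get_clock().call_later(0, ...)
            # in the hunk above: by the time `handler` runs, we are in a sane
            # reactor state (e.g. not halfway through writing a log entry).
            reactor.callLater(0, handler, *args, **kwargs)

        signal.signal(signal.SIGHUP, on_sighup)
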
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 7d309b1bb0..b4bd4d8e7a 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -89,7 +89,7 @@ async def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
- res = await hs.get_handlers().admin_handler.export_user_data(
+ res = await hs.get_admin_handler().export_user_data(
user_id, FileExfiltrationWriter(user_id, directory=directory)
)
print(res)
@@ -208,6 +208,7 @@ def start(config_options):
# Explicitly disable background processes
config.update_user_directory = False
+ config.run_background_tasks = False
config.start_pushers = False
config.send_federation = False
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c38413c893..1b511890aa 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -127,12 +127,16 @@ from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
+from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
+from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import UserPresenceState
from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@@ -454,6 +458,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
+ StatsStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
@@ -463,6 +468,7 @@ class GenericWorkerSlavedStore(
SlavedAccountDataStore,
SlavedPusherStore,
CensorEventsStore,
+ ClientIpWorkerStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
@@ -476,7 +482,9 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
+ ServerMetricsStore,
SearchWorkerStore,
+ TransactionWorkerStore,
BaseSlavedStore,
):
pass
@@ -782,10 +790,6 @@ class FederationSenderHandler:
send_queue.process_rows_for_federation(self.federation_sender, rows)
await self.update_token(token)
- # We also need to poke the federation sender when new events happen
- elif stream_name == "events":
- self.federation_sender.notify_new_events(token)
-
# ... and when new receipts happen
elif stream_name == ReceiptsStream.NAME:
await self._on_new_receipts(rows)
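
The new store classes slot into GenericWorkerSlavedStore because the worker store is a plain mixin stack: each base class contributes query methods, and Python's MRO resolves them left to right. A toy illustration of why adding a name to the bases is the whole change (all names below are invented, not Synapse's):

    class BaseStore:
        """Stand-in for BaseSlavedStore at the bottom of the stack."""

    class ServerMetricsMixin(BaseStore):
        def count_daily_users(self) -> int:
            return 42

    class TransactionMixin(BaseStore):
        def get_destination_retry_timings(self, destination: str):
            return None

    class WorkerStore(ServerMetricsMixin, TransactionMixin, BaseStore):
        # Listing one more mixin in the bases, as the hunk above does, is all
        # it takes for the worker store to gain that store's methods.
        pass

    store = WorkerStore()
    assert store.count_daily_users() == 42
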
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index dff739e106..2b5465417f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -17,14 +17,10 @@
import gc
import logging
-import math
import os
-import resource
import sys
from typing import Iterable
-from prometheus_client import Gauge
-
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
@@ -60,8 +56,6 @@ from synapse.http.server import (
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.module_api import ModuleApi
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -111,7 +105,7 @@ class SynapseHomeServer(HomeServer):
additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
- module_api = ModuleApi(self, self.get_auth_handler())
+ module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
@@ -334,20 +328,6 @@ class SynapseHomeServer(HomeServer):
logger.warning("Unrecognized listener type: %s", listener.type)
-# Gauges to expose monthly active user control metrics
-current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
-current_mau_by_service_gauge = Gauge(
- "synapse_admin_mau_current_mau_by_service",
- "Current MAU by service",
- ["app_service"],
-)
-max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
-registered_reserved_users_mau_gauge = Gauge(
- "synapse_admin_mau:registered_reserved_users",
- "Registered users with reserved threepids",
-)
-
-
def setup(config_options):
"""
Args:
@@ -389,8 +369,6 @@ def setup(config_options):
except UpgradeDatabaseException as e:
quit_with_error("Failed to upgrade database: %s" % (e,))
- hs.setup_master()
-
async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
@@ -486,92 +464,6 @@ class SynapseService(service.Service):
return self._port.stopListening()
-# Contains the list of processes we will be monitoring
-# currently either 0 or 1
-_stats_process = []
-
-
-async def phone_stats_home(hs, stats, stats_process=_stats_process):
- logger.info("Gathering stats for reporting")
- now = int(hs.get_clock().time())
- uptime = int(now - hs.start_time)
- if uptime < 0:
- uptime = 0
-
- #
- # Performance statistics. Keep this early in the function to maintain the reliability of the `test_performance_100` test.
- #
- old = stats_process[0]
- new = (now, resource.getrusage(resource.RUSAGE_SELF))
- stats_process[0] = new
-
- # Get RSS in bytes
- stats["memory_rss"] = new[1].ru_maxrss
-
- # Get CPU time in % of a single core, not % of all cores
- used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
- old[1].ru_utime + old[1].ru_stime
- )
- if used_cpu_time == 0 or new[0] == old[0]:
- stats["cpu_average"] = 0
- else:
- stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
-
- #
- # General statistics
- #
-
- stats["homeserver"] = hs.config.server_name
- stats["server_context"] = hs.config.server_context
- stats["timestamp"] = now
- stats["uptime_seconds"] = uptime
- version = sys.version_info
- stats["python_version"] = "{}.{}.{}".format(
- version.major, version.minor, version.micro
- )
- stats["total_users"] = await hs.get_datastore().count_all_users()
-
- total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
- stats["total_nonbridged_users"] = total_nonbridged_users
-
- daily_user_type_results = await hs.get_datastore().count_daily_user_type()
- for name, count in daily_user_type_results.items():
- stats["daily_user_type_" + name] = count
-
- room_count = await hs.get_datastore().get_room_count()
- stats["total_room_count"] = room_count
-
- stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
-
- r30_results = await hs.get_datastore().count_r30_users()
- for name, count in r30_results.items():
- stats["r30_users_" + name] = count
-
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = hs.config.caches.global_factor
- stats["event_cache_size"] = hs.config.caches.event_cache_size
-
- #
- # Database version
- #
-
- # This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
-
- logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
- try:
- await hs.get_proxied_http_client().put_json(
- hs.config.report_stats_endpoint, stats
- )
- except Exception as e:
- logger.warning("Error reporting stats: %s", e)
-
-
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
@@ -597,81 +489,6 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
- clock = hs.get_clock()
-
- stats = {}
-
- def performance_stats_init():
- _stats_process.clear()
- _stats_process.append(
- (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
- )
-
- def start_phone_stats_home():
- return run_as_background_process(
- "phone_stats_home", phone_stats_home, hs, stats
- )
-
- def generate_user_daily_visit_stats():
- return run_as_background_process(
- "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
- )
-
- # Rather than updating on a per-session basis, batch up the requests.
- # If you increase the loop period, the accuracy of the user_daily_visits
- # table will decrease.
- clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
-
- # monthly active user limiting functionality
- def reap_monthly_active_users():
- return run_as_background_process(
- "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
- )
-
- clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
- reap_monthly_active_users()
-
- async def generate_monthly_active_users():
- current_mau_count = 0
- current_mau_count_by_service = {}
- reserved_users = ()
- store = hs.get_datastore()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- current_mau_count = await store.get_monthly_active_count()
- current_mau_count_by_service = (
- await store.get_monthly_active_count_by_service()
- )
- reserved_users = await store.get_registered_reserved_users()
- current_mau_gauge.set(float(current_mau_count))
-
- for app_service, count in current_mau_count_by_service.items():
- current_mau_by_service_gauge.labels(app_service).set(float(count))
-
- registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
- max_mau_gauge.set(float(hs.config.max_mau_value))
-
- def start_generate_monthly_active_users():
- return run_as_background_process(
- "generate_monthly_active_users", generate_monthly_active_users
- )
-
- start_generate_monthly_active_users()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
- # End of monthly active user settings
-
- if hs.config.report_stats:
- logger.info("Scheduling stats reporting for 3 hour intervals")
- clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
-
- # We need to defer this init for the cases that we daemonize
- # otherwise the process ID we get is that of the non-daemon process
- clock.call_later(0, performance_stats_init)
-
- # We wait 5 minutes to send the first set of stats as the server can
- # be quite busy the first few minutes
- clock.call_later(5 * 60, start_phone_stats_home)
-
_base.start_reactor(
"synapse-homeserver",
soft_file_limit=hs.config.soft_file_limit,
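
Most of the deletions in this file come from replacing hand-written start_* shims around run_as_background_process with the @wrap_as_background_process decorator. A rough sketch of what such a decorator amounts to, as an illustration only with Synapse's helper stubbed out, not the real implementation:

    import functools
    import logging

    logger = logging.getLogger(__name__)

    def run_as_background_process(name, func, *args, **kwargs):
        # Stand-in for Synapse's helper, which runs the coroutine in its own
        # logging context and records metrics under `name`.
        logger.info("Starting background process %s", name)
        return func(*args, **kwargs)

    def wrap_as_background_process(name):
        """Decorator form: every call to the wrapped function is launched as
        a named background process, so call sites (e.g. looping_call) can
        pass the function directly instead of a start_* wrapper."""

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return run_as_background_process(name, func, *args, **kwargs)

            return wrapper

        return decorator
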
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
new file mode 100644
index 0000000000..c38cf8231f
--- /dev/null
+++ b/synapse/app/phone_stats_home.py
@@ -0,0 +1,190 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import math
+import resource
+import sys
+
+from prometheus_client import Gauge
+
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+
+logger = logging.getLogger("synapse.app.homeserver")
+
+# Contains the list of processes we will be monitoring
+# currently either 0 or 1
+_stats_process = []
+
+# Gauges to expose monthly active user control metrics
+current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+current_mau_by_service_gauge = Gauge(
+ "synapse_admin_mau_current_mau_by_service",
+ "Current MAU by service",
+ ["app_service"],
+)
+max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
+registered_reserved_users_mau_gauge = Gauge(
+ "synapse_admin_mau:registered_reserved_users",
+ "Registered users with reserved threepids",
+)
+
+
+@wrap_as_background_process("phone_stats_home")
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
+ logger.info("Gathering stats for reporting")
+ now = int(hs.get_clock().time())
+ uptime = int(now - hs.start_time)
+ if uptime < 0:
+ uptime = 0
+
+ #
+ # Performance statistics. Keep this early in the function to maintain the reliability of the `test_performance_100` test.
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # General statistics
+ #
+
+ stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
+ stats["timestamp"] = now
+ stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
+ stats["total_users"] = await hs.get_datastore().count_all_users()
+
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+ for name, count in daily_user_type_results.items():
+ stats["daily_user_type_" + name] = count
+
+ room_count = await hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
+
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+
+ r30_results = await hs.get_datastore().count_r30_users()
+ for name, count in r30_results.items():
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = hs.config.caches.global_factor
+ stats["event_cache_size"] = hs.config.caches.event_cache_size
+
+ #
+ # Database version
+ #
+
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
+
+ #
+ # Logging configuration
+ #
+ synapse_logger = logging.getLogger("synapse")
+ log_level = synapse_logger.getEffectiveLevel()
+ stats["log_level"] = logging.getLevelName(log_level)
+
+ logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ try:
+ await hs.get_proxied_http_client().put_json(
+ hs.config.report_stats_endpoint, stats
+ )
+ except Exception as e:
+ logger.warning("Error reporting stats: %s", e)
+
+
+def start_phone_stats_home(hs):
+ """
+ Start the background tasks which report phone home stats.
+ """
+ clock = hs.get_clock()
+
+ stats = {}
+
+ def performance_stats_init():
+ _stats_process.clear()
+ _stats_process.append(
+ (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
+ )
+
+ # Rather than updating on a per-session basis, batch up the requests.
+ # If you increase the loop period, the accuracy of the user_daily_visits
+ # table will decrease.
+ clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000)
+
+ # monthly active user limiting functionality
+ clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60)
+ hs.get_datastore().reap_monthly_active_users()
+
+ @wrap_as_background_process("generate_monthly_active_users")
+ async def generate_monthly_active_users():
+ current_mau_count = 0
+ current_mau_count_by_service = {}
+ reserved_users = ()
+ store = hs.get_datastore()
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ current_mau_count = await store.get_monthly_active_count()
+ current_mau_count_by_service = (
+ await store.get_monthly_active_count_by_service()
+ )
+ reserved_users = await store.get_registered_reserved_users()
+ current_mau_gauge.set(float(current_mau_count))
+
+ for app_service, count in current_mau_count_by_service.items():
+ current_mau_by_service_gauge.labels(app_service).set(float(count))
+
+ registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
+ max_mau_gauge.set(float(hs.config.max_mau_value))
+
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ generate_monthly_active_users()
+ clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+ # End of monthly active user settings
+
+ if hs.config.report_stats:
+ logger.info("Scheduling stats reporting for 3 hour intervals")
+ clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
+
+ # We need to defer this init for the cases that we daemonize
+ # otherwise the process ID we get is that of the non-daemon process
+ clock.call_later(0, performance_stats_init)
+
+ # We wait 5 minutes to send the first set of stats as the server can
+ # be quite busy the first few minutes
+ clock.call_later(5 * 60, phone_stats_home, hs, stats)
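
For reference, the cpu_average figure computed above is just the process CPU time consumed between two rusage samples, expressed as a percentage of one core over the wall-clock gap. The same calculation in isolation:

    import math
    import resource
    import time

    def sample():
        # Pairs a wall-clock timestamp with a rusage snapshot, matching the
        # (now, getrusage(...)) tuples kept in _stats_process above.
        return (int(time.time()), resource.getrusage(resource.RUSAGE_SELF))

    def cpu_average(old, new) -> int:
        used = (new[1].ru_utime + new[1].ru_stime) - (
            old[1].ru_utime + old[1].ru_stime
        )
        if used == 0 or new[0] == old[0]:
            return 0
        return math.floor(used / (new[0] - old[0]) * 100)

    first = sample()
    time.sleep(1)
    print(cpu_average(first, sample()))  # ~0 for an idle process
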
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 13ec1f71a6..3944780a42 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,14 +14,15 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes
-from synapse.appservice.api import ApplicationServiceApi
-from synapse.types import GroupID, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.events import EventBase
+from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
+from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
+ from synapse.appservice.api import ApplicationServiceApi
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -32,38 +33,6 @@ class ApplicationServiceState:
UP = "up"
-class AppServiceTransaction:
- """Represents an application service transaction."""
-
- def __init__(self, service, id, events):
- self.service = service
- self.id = id
- self.events = events
-
- async def send(self, as_api: ApplicationServiceApi) -> bool:
- """Sends this transaction using the provided AS API interface.
-
- Args:
- as_api: The API to use to send.
- Returns:
- True if the transaction was sent.
- """
- return await as_api.push_bulk(
- service=self.service, events=self.events, txn_id=self.id
- )
-
- async def complete(self, store: "DataStore") -> None:
- """Completes this transaction as successful.
-
- Marks this transaction ID on the application service and removes the
- transaction contents from the database.
-
- Args:
- store: The database store to operate on.
- """
- await store.complete_appservice_txn(service=self.service, txn_id=self.id)
-
-
class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -83,14 +52,15 @@ class ApplicationService:
self,
token,
hostname,
+ id,
+ sender,
url=None,
namespaces=None,
hs_token=None,
- sender=None,
- id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
+ supports_ephemeral=False,
):
self.token = token
self.url = (
@@ -102,6 +72,7 @@ class ApplicationService:
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.ip_range_whitelist = ip_range_whitelist
+ self.supports_ephemeral = supports_ephemeral
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@@ -161,19 +132,21 @@ class ApplicationService:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
- def _matches_regex(self, test_string, namespace_key):
+ def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
for regex_obj in self.namespaces[namespace_key]:
if regex_obj["regex"].match(test_string):
return regex_obj
return None
- def _is_exclusive(self, ns_key, test_string):
+ def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
return False
- async def _matches_user(self, event, store):
+ async def _matches_user(
+ self, event: Optional[EventBase], store: Optional["DataStore"] = None
+ ) -> bool:
if not event:
return False
@@ -188,11 +161,22 @@ class ApplicationService:
if not store:
return False
- does_match = await self._matches_user_in_member_list(event.room_id, store)
+ does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
@cached(num_args=1, cache_context=True)
- async def _matches_user_in_member_list(self, room_id, store, cache_context):
+ async def matches_user_in_member_list(
+ self, room_id: str, store: "DataStore", cache_context: _CacheContext,
+ ) -> bool:
+ """Check if this service is interested a room based upon it's membership
+
+ Args:
+ room_id: The room to check.
+ store: The datastore to query.
+
+ Returns:
+ True if this service would like to know about this room.
+ """
member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
@@ -203,12 +187,14 @@ class ApplicationService:
return True
return False
- def _matches_room_id(self, event):
+ def _matches_room_id(self, event: EventBase) -> bool:
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
- async def _matches_aliases(self, event, store):
+ async def _matches_aliases(
+ self, event: EventBase, store: Optional["DataStore"] = None
+ ) -> bool:
if not store or not event:
return False
@@ -218,12 +204,15 @@ class ApplicationService:
return True
return False
- async def is_interested(self, event, store=None) -> bool:
+ async def is_interested(
+ self, event: EventBase, store: Optional["DataStore"] = None
+ ) -> bool:
"""Check if this service is interested in this event.
Args:
- event(Event): The event to check.
- store(DataStore)
+ event: The event to check.
+ store: The datastore to query.
+
Returns:
True if this service would like to know about this event.
"""
@@ -231,39 +220,66 @@ class ApplicationService:
if self._matches_room_id(event):
return True
+ # This will check the namespaces first before
+ # checking the store, so it should be run before _matches_aliases.
+ if await self._matches_user(event, store):
+ return True
+
+ # This will check the store, so it should be run last.
if await self._matches_aliases(event, store):
return True
- if await self._matches_user(event, store):
+ return False
+
+ @cached(num_args=1)
+ async def is_interested_in_presence(
+ self, user_id: UserID, store: "DataStore"
+ ) -> bool:
+ """Check if this service is interested a user's presence
+
+ Args:
+ user_id: The user to check.
+ store: The datastore to query.
+
+ Returns:
+ True if this service would like to know about presence for this user.
+ """
+ # Find all the rooms the sender is in
+ if self.is_interested_in_user(user_id.to_string()):
return True
+ room_ids = await store.get_rooms_for_user(user_id.to_string())
+ # Then find out if the appservice is interested in any of those rooms
+ for room_id in room_ids:
+ if await self.matches_user_in_member_list(room_id, store):
+ return True
return False
- def is_interested_in_user(self, user_id):
+ def is_interested_in_user(self, user_id: str) -> bool:
return (
- self._matches_regex(user_id, ApplicationService.NS_USERS)
+ bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
or user_id == self.sender
)
- def is_interested_in_alias(self, alias):
+ def is_interested_in_alias(self, alias: str) -> bool:
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
- def is_interested_in_room(self, room_id):
+ def is_interested_in_room(self, room_id: str) -> bool:
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
- def is_exclusive_user(self, user_id):
+ def is_exclusive_user(self, user_id: str) -> bool:
return (
self._is_exclusive(ApplicationService.NS_USERS, user_id)
or user_id == self.sender
)
- def is_interested_in_protocol(self, protocol):
+ def is_interested_in_protocol(self, protocol: str) -> bool:
return protocol in self.protocols
- def is_exclusive_alias(self, alias):
+ def is_exclusive_alias(self, alias: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
- def is_exclusive_room(self, room_id):
+ def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exclusive_user_regexes(self):
@@ -276,14 +292,14 @@ class ApplicationService:
if regex_obj["exclusive"]
]
- def get_groups_for_user(self, user_id):
+ def get_groups_for_user(self, user_id: str) -> Iterable[str]:
"""Get the groups that this user is associated with by this AS
Args:
- user_id (str): The ID of the user.
+ user_id: The ID of the user.
Returns:
- iterable[str]: an iterable that yields group_id strings.
+ An iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
@@ -291,7 +307,7 @@ class ApplicationService:
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
)
- def is_rate_limited(self):
+ def is_rate_limited(self) -> bool:
return self.rate_limited
def __str__(self):
@@ -300,3 +316,45 @@ class ApplicationService:
dict_copy["token"] = "<redacted>"
dict_copy["hs_token"] = "<redacted>"
return "ApplicationService: %s" % (dict_copy,)
+
+
+class AppServiceTransaction:
+ """Represents an application service transaction."""
+
+ def __init__(
+ self,
+ service: ApplicationService,
+ id: int,
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ ):
+ self.service = service
+ self.id = id
+ self.events = events
+ self.ephemeral = ephemeral
+
+ async def send(self, as_api: "ApplicationServiceApi") -> bool:
+ """Sends this transaction using the provided AS API interface.
+
+ Args:
+ as_api: The API to use to send.
+ Returns:
+ True if the transaction was sent.
+ """
+ return await as_api.push_bulk(
+ service=self.service,
+ events=self.events,
+ ephemeral=self.ephemeral,
+ txn_id=self.id,
+ )
+
+ async def complete(self, store: "DataStore") -> None:
+ """Completes this transaction as successful.
+
+ Marks this transaction ID on the application service and removes the
+ transaction contents from the database.
+
+ Args:
+ store: The database store to operate on.
+ """
+ await store.complete_appservice_txn(service=self.service, txn_id=self.id)
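
The is_interested_in_user() check that the new is_interested_in_presence() method builds on reduces to regex matching over the appservice's registered namespaces. A self-contained approximation (the registration data below is invented):

    import re

    # A registration declares per-namespace regexes; a user matches if any
    # regex in the "users" namespace does, or if it is the AS's own sender.
    namespaces = {"users": [{"regex": re.compile(r"@irc_.*"), "exclusive": True}]}
    sender = "@ircbot:example.com"

    def is_interested_in_user(user_id: str) -> bool:
        return user_id == sender or any(
            ns["regex"].match(user_id) for ns in namespaces["users"]
        )

    assert is_interested_in_user("@irc_alice:example.com")
    assert not is_interested_in_user("@bob:example.com")
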
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index c526c28b93..e366a982b8 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,12 +14,13 @@
# limitations under the License.
import logging
import urllib
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
+from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
@@ -93,7 +94,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
- )
+ ) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
if service.url is None:
@@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)
- async def push_bulk(self, service, events, txn_id=None):
+ async def push_bulk(
+ self,
+ service: "ApplicationService",
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ txn_id: Optional[int] = None,
+ ):
if service.url is None:
return True
@@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
- txn_id = str(0)
- txn_id = str(txn_id)
+ txn_id = 0
+
+ uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
+
+ # Never send ephemeral events to appservices that do not support it
+ if service.supports_ephemeral:
+ body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
+ else:
+ body = {"events": events}
- uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
await self.put_json(
- uri=uri,
- json_body={"events": events},
- args={"access_token": service.hs_token},
+ uri=uri, json_body=body, args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
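
The wire-level effect of the push_bulk() change is easiest to see as the request body alone: ephemeral events ride along under the unstable MSC2409 key, and only for appservices that advertise support. A sketch of the body construction (the helper name is invented):

    def build_transaction_body(events, ephemeral, supports_ephemeral):
        # Mirrors the branch in push_bulk() above: never send ephemeral
        # events to appservices that do not support them.
        if supports_ephemeral:
            return {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
        return {"events": events}

    body = build_transaction_body(
        events=[{"type": "m.room.message", "content": {"body": "hi"}}],
        ephemeral=[{"type": "m.typing", "content": {"user_ids": []}}],
        supports_ephemeral=True,
    )
    # PUT {service.url}/transactions/{txn_id}?access_token=... with this body.
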
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 8eb8c6f51c..58291afc22 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -49,14 +49,24 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
+from typing import List
-from synapse.appservice import ApplicationServiceState
+from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+# Maximum number of events to provide in an AS transaction.
+MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
+
+# Maximum number of ephemeral events to provide in an AS transaction.
+MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
+
+
class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
@@ -82,8 +92,13 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
- def submit_event_for_as(self, service, event):
- self.queuer.enqueue(service, event)
+ def submit_event_for_as(self, service: ApplicationService, event: EventBase):
+ self.queuer.enqueue_event(service, event)
+
+ def submit_ephemeral_events_for_as(
+ self, service: ApplicationService, events: List[JsonDict]
+ ):
+ self.queuer.enqueue_ephemeral(service, events)
class _ServiceQueuer:
@@ -96,17 +111,15 @@ class _ServiceQueuer:
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
+ self.queued_ephemeral = {} # dict of {service_id: [events]}
# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
- def enqueue(self, service, event):
- self.queued_events.setdefault(service.id, []).append(event)
-
+ def _start_background_request(self, service):
# start a sender for this appservice if we don't already have one
-
if service.id in self.requests_in_flight:
return
@@ -114,7 +127,15 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service
)
- async def _send_request(self, service):
+ def enqueue_event(self, service: ApplicationService, event: EventBase):
+ self.queued_events.setdefault(service.id, []).append(event)
+ self._start_background_request(service)
+
+ def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
+ self.queued_ephemeral.setdefault(service.id, []).extend(events)
+ self._start_background_request(service)
+
+ async def _send_request(self, service: ApplicationService):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@@ -122,11 +143,19 @@ class _ServiceQueuer:
self.requests_in_flight.add(service.id)
try:
while True:
- events = self.queued_events.pop(service.id, [])
- if not events:
+ all_events = self.queued_events.get(service.id, [])
+ events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+ del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+
+ all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
+ ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+ del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+
+ if not events and not ephemeral:
return
+
try:
- await self.txn_ctrl.send(service, events)
+ await self.txn_ctrl.send(service, events, ephemeral)
except Exception:
logger.exception("AS request failed")
finally:
@@ -158,9 +187,16 @@ class _TransactionController:
# for UTs
self.RECOVERER_CLASS = _Recoverer
- async def send(self, service, events):
+ async def send(
+ self,
+ service: ApplicationService,
+ events: List[EventBase],
+ ephemeral: List[JsonDict] = [],
+ ):
try:
- txn = await self.store.create_appservice_txn(service=service, events=events)
+ txn = await self.store.create_appservice_txn(
+ service=service, events=events, ephemeral=ephemeral
+ )
service_is_up = await self._is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
@@ -204,7 +240,7 @@ class _TransactionController:
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
- async def _is_service_up(self, service):
+ async def _is_service_up(self, service: ApplicationService) -> bool:
state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None
|