diff --git a/synapse/__init__.py b/synapse/__init__.py
index 26bdfec33a..5e65033061 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -21,8 +21,8 @@ import os
import sys
# Check that we're not running on an unsupported Python version.
-if sys.version_info < (3, 6):
- print("Synapse requires Python 3.6 or above.")
+if sys.version_info < (3, 7):
+ print("Synapse requires Python 3.7 or above.")
sys.exit(1)
# Twisted and canonicaljson will fail to import when this file is executed to
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.51.0"
+__version__ = "1.52.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index f9f9467dc1..bd49fa6a5f 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -28,7 +28,6 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
-WEB_CLIENT_PREFIX = "/_matrix/client"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_R0_PREFIX = "/_matrix/media/r0"
MEDIA_V3_PREFIX = "/_matrix/media/v3"
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 579adbbca0..bbab8a052a 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -16,7 +16,6 @@ import atexit
import gc
import logging
import os
-import platform
import signal
import socket
import sys
@@ -436,7 +435,8 @@ async def start(hs: "HomeServer") -> None:
# before we start the listeners.
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
+ m = module(config=config, api=module_api)
+ logger.info("Loaded module %s", m)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
@@ -468,15 +468,13 @@ async def start(hs: "HomeServer") -> None:
# everything currently allocated are things that will be used for the
# rest of time. Doing so means less work each GC (hopefully).
#
- # This only works on Python 3.7
- if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+ # PyPy does not (yet?) implement gc.freeze()
+ if hasattr(gc, "freeze"):
gc.collect()
gc.freeze()
- # Speed up shutdowns by freezing all allocated objects. This moves everything
- # into the permanent generation and excludes them from the final GC.
- # Unfortunately only works on Python 3.7
- if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+ # Speed up shutdowns by freezing all allocated objects. This moves everything
+ # into the permanent generation and excludes them from the final GC.
atexit.register(gc.freeze)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index efedcc8889..24d55b0494 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -21,7 +21,6 @@ from typing import Dict, Iterable, Iterator, List
from twisted.internet.tcp import Port
from twisted.web.resource import EncodingResourceWrapper, Resource
from twisted.web.server import GzipEncoderFactory
-from twisted.web.static import File
import synapse
import synapse.config.logger
@@ -33,7 +32,6 @@ from synapse.api.urls import (
MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX,
STATIC_PREFIX,
- WEB_CLIENT_PREFIX,
)
from synapse.app import _base
from synapse.app._base import (
@@ -53,7 +51,6 @@ from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import (
OptionsResource,
RootOptionsRedirectResource,
- RootRedirect,
StaticResource,
)
from synapse.http.site import SynapseSite
@@ -134,15 +131,12 @@ class SynapseHomeServer(HomeServer):
# Try to find something useful to serve at '/':
#
# 1. Redirect to the web client if it is an HTTP(S) URL.
- # 2. Redirect to the web client served via Synapse.
- # 3. Redirect to the static "Synapse is running" page.
- # 4. Do not redirect and use a blank resource.
- if self.config.server.web_client_location_is_redirect:
+ # 2. Redirect to the static "Synapse is running" page.
+ # 3. Do not redirect and use a blank resource.
+ if self.config.server.web_client_location:
root_resource: Resource = RootOptionsRedirectResource(
self.config.server.web_client_location
)
- elif WEB_CLIENT_PREFIX in resources:
- root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
else:
@@ -270,28 +264,6 @@ class SynapseHomeServer(HomeServer):
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
- if name == "webclient":
- # webclient listeners are deprecated as of Synapse v1.51.0, remove it
- # in > v1.53.0.
- webclient_loc = self.config.server.web_client_location
-
- if webclient_loc is None:
- logger.warning(
- "Not enabling webclient resource, as web_client_location is unset."
- )
- elif self.config.server.web_client_location_is_redirect:
- resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc)
- else:
- logger.warning(
- "Running webclient on the same domain is not recommended: "
- "https://github.com/matrix-org/synapse#security-note - "
- "after you move webclient to different host you can set "
- "web_client_location to its full URL to enable redirection."
- )
- # GZip is disabled here due to
- # https://twistedmatrix.com/trac/ticket/7678
- resources[WEB_CLIENT_PREFIX] = File(webclient_loc)
-
if name == "metrics" and self.config.metrics.enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 8c9ff93b2c..7dbebd97b5 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -351,11 +351,13 @@ class AppServiceTransaction:
id: int,
events: List[EventBase],
ephemeral: List[JsonDict],
+ to_device_messages: List[JsonDict],
):
self.service = service
self.id = id
self.events = events
self.ephemeral = ephemeral
+ self.to_device_messages = to_device_messages
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@@ -369,6 +371,7 @@ class AppServiceTransaction:
service=self.service,
events=self.events,
ephemeral=self.ephemeral,
+ to_device_messages=self.to_device_messages,
txn_id=self.id,
)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index def4424af0..73be7ff3d4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -218,8 +218,23 @@ class ApplicationServiceApi(SimpleHttpClient):
service: "ApplicationService",
events: List[EventBase],
ephemeral: List[JsonDict],
+ to_device_messages: List[JsonDict],
txn_id: Optional[int] = None,
) -> bool:
+ """
+ Push data to an application service.
+
+ Args:
+ service: The application service to send to.
+ events: The persistent events to send.
+ ephemeral: The ephemeral events to send.
+ to_device_messages: The to-device messages to send.
+ txn_id: An unique ID to assign to this transaction. Application services should
+ deduplicate transactions received with identitical IDs.
+
+ Returns:
+ True if the task succeeded, False if it failed.
+ """
if service.url is None:
return True
@@ -237,13 +252,15 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it
+ body: Dict[str, List[JsonDict]] = {"events": serialized_events}
if service.supports_ephemeral:
- body = {
- "events": serialized_events,
- "de.sorunome.msc2409.ephemeral": ephemeral,
- }
- else:
- body = {"events": serialized_events}
+ body.update(
+ {
+ # TODO: Update to stable prefixes once MSC2409 completes FCP merge.
+ "de.sorunome.msc2409.ephemeral": ephemeral,
+ "de.sorunome.msc2409.to_device": to_device_messages,
+ }
+ )
try:
await self.put_json(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 185e3a5278..c42fa32fff 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,7 +48,16 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Callable,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Set,
+)
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.appservice.api import ApplicationServiceApi
@@ -71,6 +80,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
# Maximum number of ephemeral events to provide in an AS transaction.
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
+# Maximum number of to-device messages to provide in an AS transaction.
+MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100
+
class ApplicationServiceScheduler:
"""Public facing API for this module. Does the required DI to tie the
@@ -97,15 +109,40 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
- def submit_event_for_as(
- self, service: ApplicationService, event: EventBase
+ def enqueue_for_appservice(
+ self,
+ appservice: ApplicationService,
+ events: Optional[Collection[EventBase]] = None,
+ ephemeral: Optional[Collection[JsonDict]] = None,
+ to_device_messages: Optional[Collection[JsonDict]] = None,
) -> None:
- self.queuer.enqueue_event(service, event)
+ """
+ Enqueue some data to be sent off to an application service.
- def submit_ephemeral_events_for_as(
- self, service: ApplicationService, events: List[JsonDict]
- ) -> None:
- self.queuer.enqueue_ephemeral(service, events)
+ Args:
+ appservice: The application service to create and send a transaction to.
+ events: The persistent room events to send.
+ ephemeral: The ephemeral events to send.
+ to_device_messages: The to-device messages to send. These differ from normal
+ to-device messages sent to clients, as they have 'to_device_id' and
+ 'to_user_id' fields.
+ """
+ # We purposefully allow this method to run with empty events/ephemeral
+ # collections, so that callers do not need to check iterable size themselves.
+ if not events and not ephemeral and not to_device_messages:
+ return
+
+ if events:
+ self.queuer.queued_events.setdefault(appservice.id, []).extend(events)
+ if ephemeral:
+ self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral)
+ if to_device_messages:
+ self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
+ to_device_messages
+ )
+
+ # Kick off a new application service transaction
+ self.queuer.start_background_request(appservice)
class _ServiceQueuer:
@@ -121,13 +158,15 @@ class _ServiceQueuer:
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
+ # dict of {service_id: [to_device_message_json]}
+ self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
- def _start_background_request(self, service: ApplicationService) -> None:
+ def start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight:
return
@@ -136,16 +175,6 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service
)
- def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
- self.queued_events.setdefault(service.id, []).append(event)
- self._start_background_request(service)
-
- def enqueue_ephemeral(
- self, service: ApplicationService, events: List[JsonDict]
- ) -> None:
- self.queued_ephemeral.setdefault(service.id, []).extend(events)
- self._start_background_request(service)
-
async def _send_request(self, service: ApplicationService) -> None:
# sanity-check: we shouldn't get here if this service already has a sender
# running.
@@ -162,11 +191,21 @@ class _ServiceQueuer:
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
- if not events and not ephemeral:
+ all_to_device_messages = self.queued_to_device_messages.get(
+ service.id, []
+ )
+ to_device_messages_to_send = all_to_device_messages[
+ :MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION
+ ]
+ del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
+
+ if not events and not ephemeral and not to_device_messages_to_send:
return
try:
- await self.txn_ctrl.send(service, events, ephemeral)
+ await self.txn_ctrl.send(
+ service, events, ephemeral, to_device_messages_to_send
+ )
except Exception:
logger.exception("AS request failed")
finally:
@@ -198,10 +237,24 @@ class _TransactionController:
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
+ to_device_messages: Optional[List[JsonDict]] = None,
) -> None:
+ """
+ Create a transaction with the given data and send to the provided
+ application service.
+
+ Args:
+ service: The application service to send the transaction to.
+ events: The persistent events to include in the transaction.
+ ephemeral: The ephemeral events to include in the transaction.
+ to_device_messages: The to-device messages to include in the transaction.
+ """
try:
txn = await self.store.create_appservice_txn(
- service=service, events=events, ephemeral=ephemeral or []
+ service=service,
+ events=events,
+ ephemeral=ephemeral or [],
+ to_device_messages=to_device_messages or [],
)
service_is_up = await self._is_service_up(service)
if service_is_up:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index dbaeb10918..e4719d19b8 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -24,8 +24,6 @@ class ExperimentalConfig(Config):
def read_config(self, config: JsonDict, **kwargs):
experimental = config.get("experimental_features") or {}
- # Whether to enable experimental MSC1849 (aka relations) support
- self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
# MSC3440 (thread relation)
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
@@ -54,3 +52,10 @@ class ExperimentalConfig(Config):
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
+
+ # MSC2409 (this setting only relates to optionally sending to-device messages).
+ # Presence, typing and read receipt EDUs are already sent to application services that
+ # have opted in to receive them. If enabled, this adds to-device messages to that list.
+ self.msc2409_to_device_messages_enabled: bool = experimental.get(
+ "msc2409_to_device_messages_enabled", False
+ )
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index 85fb05890d..2ef02b8f55 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -41,9 +41,9 @@ class ModulesConfig(Config):
# documentation on how to configure or create custom modules for Synapse.
#
modules:
- # - module: my_super_module.MySuperClass
- # config:
- # do_thing: true
- # - module: my_other_super_module.SomeClass
- # config: {}
+ #- module: my_super_module.MySuperClass
+ # config:
+ # do_thing: true
+ #- module: my_other_super_module.SomeClass
+ # config: {}
"""
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 36636ab07e..e9ccf1bd62 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -134,6 +134,14 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.003, "burst_count": 5},
)
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {}),
+ defaults={
+ "per_second": self.rc_message.per_second,
+ "burst_count": self.rc_message.burst_count,
+ },
+ )
+
def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
@@ -168,6 +176,9 @@ class RatelimitConfig(Config):
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
# - two for ratelimiting how often invites can be sent in a room or to a
# specific user.
+ # - one for ratelimiting 3PID invites (i.e. invites sent to a third-party ID
+ # such as an email address or a phone number) based on the account that's
+ # sending the invite.
#
# The defaults are as shown below.
#
@@ -217,6 +228,10 @@ class RatelimitConfig(Config):
# per_user:
# per_second: 0.003
# burst_count: 5
+ #
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
# Ratelimiting settings for incoming federation
#
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 7a059c6dec..ea9b50fe97 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -190,6 +190,8 @@ class RegistrationConfig(Config):
# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
+ self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False)
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -446,6 +448,16 @@ class RegistrationConfig(Config):
# Defaults to true.
#
#auto_join_rooms_for_guests: false
+
+ # Whether to inhibit errors raised when registering a new account if the user ID
+ # already exists. If turned on, that requests to /register/available will always
+ # show a user ID as available, and Synapse won't raise an error when starting
+ # a registration with a user ID that already exists. However, Synapse will still
+ # raise an error if the registration completes and the username conflicts.
+ #
+ # Defaults to false.
+ #
+ #inhibit_user_in_use_error: true
"""
% locals()
)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index f200d0c1f1..7bc9624546 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -179,7 +179,6 @@ KNOWN_RESOURCES = {
"openid",
"replication",
"static",
- "webclient",
}
@@ -489,6 +488,19 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # The maximum size an avatar can have, in bytes.
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size is not None:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ # The MIME types allowed for an avatar.
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes")
+ if self.allowed_avatar_mimetypes and not isinstance(
+ self.allowed_avatar_mimetypes,
+ list,
+ ):
+ raise ConfigError("allowed_avatar_mimetypes must be a list")
+
self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
# no_tls is not really supported any more, but let's grandfather it in
@@ -506,16 +518,12 @@ class ServerConfig(Config):
self.listeners = l2
self.web_client_location = config.get("web_client_location", None)
- self.web_client_location_is_redirect = self.web_client_location and (
+ # Non-HTTP(S) web client location is not supported.
+ if self.web_client_location and not (
self.web_client_location.startswith("http://")
or self.web_client_location.startswith("https://")
- )
- # A non-HTTP(S) web client location is deprecated.
- if self.web_client_location and not self.web_client_location_is_redirect:
- logger.warning(NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING)
-
- # Warn if webclient is configured for a worker.
- _warn_if_webclient_configured(self.listeners)
+ ):
+ raise ConfigError("web_client_location must point to a HTTP(S) URL.")
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
@@ -643,19 +651,6 @@ class ServerConfig(Config):
False,
)
- # List of users trialing the new experimental default push rules. This setting is
- # not included in the sample configuration file on purpose as it's a temporary
- # hack, so that some users can trial the new defaults without impacting every
- # user on the homeserver.
- users_new_default_push_rules: list = (
- config.get("users_new_default_push_rules") or []
- )
- if not isinstance(users_new_default_push_rules, list):
- raise ConfigError("'users_new_default_push_rules' must be a list")
-
- # Turn the list into a set to improve lookup speed.
- self.users_new_default_push_rules: set = set(users_new_default_push_rules)
-
# Whitelist of domain names that given next_link parameters must have
next_link_domain_whitelist: Optional[List[str]] = config.get(
"next_link_domain_whitelist"
@@ -1168,6 +1163,20 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # The largest allowed file size for a user avatar. Defaults to no restriction.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's media repository.
+ #
+ #max_avatar_size: 10M
+
+ # The MIME types allowed for user avatars. Defaults to no restriction.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's media repository.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -1337,11 +1346,16 @@ def parse_listener_def(listener: Any) -> ListenerConfig:
http_config = None
if listener_type == "http":
+ try:
+ resources = [
+ HttpResourceConfig(**res) for res in listener.get("resources", [])
+ ]
+ except ValueError as e:
+ raise ConfigError("Unknown listener resource") from e
+
http_config = HttpListenerConfig(
x_forwarded=listener.get("x_forwarded", False),
- resources=[
- HttpResourceConfig(**res) for res in listener.get("resources", [])
- ],
+ resources=resources,
additional_resources=listener.get("additional_resources", {}),
tag=listener.get("tag"),
)
@@ -1349,30 +1363,6 @@ def parse_listener_def(listener: Any) -> ListenerConfig:
return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
-NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING = """
-Synapse no longer supports serving a web client. To remove this warning,
-configure 'web_client_location' with an HTTP(S) URL.
-"""
-
-
-NO_MORE_WEB_CLIENT_WARNING = """
-Synapse no longer includes a web client. To redirect the root resource to a web client, configure
-'web_client_location'. To remove this warning, remove 'webclient' from the 'listeners'
-configuration.
-"""
-
-
-def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
- for listener in listeners:
- if not listener.http_options:
- continue
- for res in listener.http_options.resources:
- for name in res.names:
- if name == "webclient":
- logger.warning(NO_MORE_WEB_CLIENT_WARNING)
- return
-
-
_MANHOLE_SETTINGS_SCHEMA = {
"type": "object",
"properties": {
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 38f3cf4d33..9acb3c0cc4 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -315,10 +315,11 @@ class EventBase(metaclass=abc.ABCMeta):
redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
room_id: DictProperty[str] = DictProperty("room_id")
sender: DictProperty[str] = DictProperty("sender")
- # TODO state_key should be Optional[str], this is generally asserted in Synapse
- # by calling is_state() first (which ensures this), but it is hard (not possible?)
+ # TODO state_key should be Optional[str]. This is generally asserted in Synapse
+ # by calling is_state() first (which ensures it is not None), but it is hard (not possible?)
# to properly annotate that calling is_state() asserts that state_key exists
- # and is non-None.
+ # and is non-None. It would be better to replace such direct references with
+ # get_state_key() (and a check for None).
state_key: DictProperty[str] = DictProperty("state_key")
type: DictProperty[str] = DictProperty("type")
user_id: DictProperty[str] = DictProperty("sender")
@@ -332,7 +333,11 @@ class EventBase(metaclass=abc.ABCMeta):
return self.content["membership"]
def is_state(self) -> bool:
- return hasattr(self, "state_key") and self.state_key is not None
+ return self.get_state_key() is not None
+
+ def get_state_key(self) -> Optional[str]:
+ """Get the state key of this event, or None if it's not a state event"""
+ return self._dict.get("state_key")
def get_dict(self) -> JsonDict:
d = dict(self._dict)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 0eab1aefd6..5833fee25f 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -163,7 +163,7 @@ class EventContext:
return {
"prev_state_id": prev_state_id,
"event_type": event.type,
- "event_state_key": event.state_key if event.is_state() else None,
+ "event_state_key": event.get_state_key(),
"state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 918adeecf8..243696b357 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import re
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Union,
+)
from frozendict import frozendict
@@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
+if TYPE_CHECKING:
+ from synapse.storage.databases.main.relations import BundledAggregations
+
+
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -376,7 +390,7 @@ class EventClientSerializer:
event: Union[JsonDict, EventBase],
time_now: int,
*,
- bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
+ bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
@@ -415,7 +429,7 @@ class EventClientSerializer:
self,
event: EventBase,
time_now: int,
- aggregations: JsonDict,
+ aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -427,13 +441,18 @@ class EventClientSerializer:
serialized_event: The serialized event which may be modified.
"""
- # Make a copy in-case the object is cached.
- aggregations = aggregations.copy()
+ serialized_aggregations = {}
+
+ if aggregations.annotations:
+ serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
+
+ if aggregations.references:
+ serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
- if RelationTypes.REPLACE in aggregations:
+ if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
- edit = aggregations[RelationTypes.REPLACE]
+ edit = aggregations.replace
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
@@ -451,24 +470,28 @@ class EventClientSerializer:
else:
serialized_event["content"].pop("m.relates_to", None)
- aggregations[RelationTypes.REPLACE] = {
+ serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
}
# If this event is the start of a thread, include a summary of the replies.
- if RelationTypes.THREAD in aggregations:
- # Serialize the latest thread event.
- latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
-
- # Don't bundle aggregations as this could recurse forever.
- aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
- latest_thread_event, time_now, bundle_aggregations=None
- )
+ if aggregations.thread:
+ serialized_aggregations[RelationTypes.THREAD] = {
+ # Don't bundle aggregations as this could recurse forever.
+ "latest_event": self.serialize_event(
+ aggregations.thread.latest_event, time_now, bundle_aggregations=None
+ ),
+ "count": aggregations.thread.count,
+ "current_user_participated": aggregations.thread.current_user_participated,
+ }
# Include the bundled aggregations in the event.
- serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+ if serialized_aggregations:
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(
+ serialized_aggregations
+ )
def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index cf86934968..360d24274a 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
-from typing import Iterable, Union
+from typing import Iterable, Type, Union
import jsonschema
@@ -246,7 +246,7 @@ POWER_LEVELS_SCHEMA = {
# This could return something newer than Draft 7, but that's the current "latest"
# validator.
-def _create_power_level_validator() -> jsonschema.Draft7Validator:
+def _create_power_level_validator() -> Type[jsonschema.Draft7Validator]:
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
# by default jsonschema does not consider a frozendict to be an object so
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 77b936361a..db4fe2c798 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, Iterable, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type
from typing_extensions import Literal
@@ -36,17 +36,19 @@ from synapse.http.servlet import (
parse_integer_from_args,
parse_string_from_args,
)
-from synapse.server import HomeServer
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.ratelimitutils import FederationRateLimiter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
- def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
+ def __init__(self, hs: "HomeServer", servlet_groups: Optional[List[str]] = None):
"""Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in
@@ -113,7 +115,7 @@ class PublicRoomList(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -203,7 +205,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -251,7 +253,7 @@ class OpenIdUserInfo(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -297,7 +299,7 @@ DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
def register_servlets(
- hs: HomeServer,
+ hs: "HomeServer",
resource: HttpServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index da1fbf8b63..dff2b68359 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -15,7 +15,8 @@
import functools
import logging
import re
-from typing import Any, Awaitable, Callable, Optional, Tuple, cast
+import time
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
@@ -24,16 +25,20 @@ from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
+ active_span,
set_tag,
span_context_from_request,
+ start_active_span,
start_active_span_follows_from,
whitelisted_homeserver,
)
-from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -46,7 +51,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator:
- def __init__(self, hs: HomeServer):
+ def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
@@ -114,11 +119,11 @@ class Authenticator:
# alive
retry_timings = await self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings.retry_last_ts:
- run_in_background(self._reset_retry_timings, origin)
+ run_in_background(self.reset_retry_timings, origin)
return origin
- async def _reset_retry_timings(self, origin: str) -> None:
+ async def reset_retry_timings(self, origin: str) -> None:
try:
logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0)
@@ -227,7 +232,7 @@ class BaseFederationServlet:
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -263,9 +268,10 @@ class BaseFederationServlet:
content = parse_json_object_from_request(request)
try:
- origin: Optional[str] = await authenticator.authenticate_request(
- request, content
- )
+ with start_active_span("authenticate_request"):
+ origin: Optional[str] = await authenticator.authenticate_request(
+ request, content
+ )
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
@@ -280,32 +286,57 @@ class BaseFederationServlet:
# update the active opentracing span with the authenticated entity
set_tag("authenticated_entity", origin)
- # if the origin is authenticated and whitelisted, link to its span context
+ # if the origin is authenticated and whitelisted, use its span context
+ # as the parent.
context = None
if origin and whitelisted_homeserver(origin):
context = span_context_from_request(request)
- scope = start_active_span_follows_from(
- "incoming-federation-request", contexts=(context,) if context else ()
- )
+ if context:
+ servlet_span = active_span()
+ # a scope which uses the origin's context as a parent
+ processing_start_time = time.time()
+ scope = start_active_span_follows_from(
+ "incoming-federation-request",
+ child_of=context,
+ contexts=(servlet_span,),
+ start_time=processing_start_time,
+ )
- with scope:
- if origin and self.RATELIMIT:
- with ratelimiter.ratelimit(origin) as d:
- await d
- if request._disconnected:
- logger.warning(
- "client disconnected before we started processing "
- "request"
+ else:
+ # just use our context as a parent
+ scope = start_active_span(
+ "incoming-federation-request",
+ )
+
+ try:
+ with scope:
+ if origin and self.RATELIMIT:
+ with ratelimiter.ratelimit(origin) as d:
+ await d
+ if request._disconnected:
+ logger.warning(
+ "client disconnected before we started processing "
+ "request"
+ )
+ return None
+ response = await func(
+ origin, content, request.args, *args, **kwargs
)
- return None
+ else:
response = await func(
origin, content, request.args, *args, **kwargs
)
- else:
- response = await func(
- origin, content, request.args, *args, **kwargs
+ finally:
+ # if we used the origin's context as the parent, add a new span using
+ # the servlet span as a parent, so that we have a link
+ if context:
+ scope2 = start_active_span_follows_from(
+ "process-federation_request",
+ contexts=(scope.span,),
+ start_time=processing_start_time,
)
+ scope2.close()
return response
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index beadfa422b..d86dfede4e 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,7 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
from typing_extensions import Literal
@@ -30,11 +40,13 @@ from synapse.http.servlet import (
parse_string_from_args,
parse_strings_from_args,
)
-from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
issue_8631_logger = logging.getLogger("synapse.8631_debug")
@@ -47,7 +59,7 @@ class BaseFederationServerServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -97,11 +109,11 @@ class FederationSendServlet(BaseFederationServerServlet):
)
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+ DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"]
device_list_updates = [
edu.content
for edu in transaction_data.get("edus", [])
- if edu.edu_type in DEVICE_UPDATE_EDUS
+ if edu.get("edu_type") in DEVICE_UPDATE_EDUS
]
if device_list_updates:
issue_8631_logger.debug(
@@ -596,7 +608,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -670,7 +682,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
@@ -706,7 +718,7 @@ class RoomComplexityServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py
index a12cd18d58..496472e1dc 100644
--- a/synapse/federation/transport/server/groups_local.py
+++ b/synapse/federation/transport/server/groups_local.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple, Type
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type
from synapse.api.errors import SynapseError
from synapse.federation.transport.server._base import (
@@ -19,10 +19,12 @@ from synapse.federation.transport.server._base import (
BaseFederationServlet,
)
from synapse.handlers.groups_local import GroupsLocalHandler
-from synapse.server import HomeServer
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.ratelimitutils import FederationRateLimiter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class BaseGroupsLocalServlet(BaseFederationServlet):
"""Abstract base class for federation servlet classes which provides a groups local handler.
@@ -32,7 +34,7 @@ class BaseGroupsLocalServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py
index b30e92a5eb..851b50152e 100644
--- a/synapse/federation/transport/server/groups_server.py
+++ b/synapse/federation/transport/server/groups_server.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple, Type
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type
from typing_extensions import Literal
@@ -22,10 +22,12 @@ from synapse.federation.transport.server._base import (
BaseFederationServlet,
)
from synapse.http.servlet import parse_string_from_args
-from synapse.server import HomeServer
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.ratelimitutils import FederationRateLimiter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class BaseGroupsServerServlet(BaseFederationServlet):
"""Abstract base class for federation servlet classes which provides a groups server handler.
@@ -35,7 +37,7 @@ class BaseGroupsServerServlet(BaseFederationServlet):
def __init__(
self,
- hs: HomeServer,
+ hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 7833e77e2b..0fb919acf6 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -55,6 +55,9 @@ class ApplicationServicesHandler:
self.clock = hs.get_clock()
self.notify_appservices = hs.config.appservice.notify_appservices
self.event_sources = hs.get_event_sources()
+ self._msc2409_to_device_messages_enabled = (
+ hs.config.experimental.msc2409_to_device_messages_enabled
+ )
self.current_max = 0
self.is_processing = False
@@ -132,7 +135,9 @@ class ApplicationServicesHandler:
# Fork off pushes to these services
for service in services:
- self.scheduler.submit_event_for_as(service, event)
+ self.scheduler.enqueue_for_appservice(
+ service, events=[event]
+ )
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@@ -199,8 +204,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
- `stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
- value for `stream_key` will cause this function to return early.
+ `stream_key` can be "typing_key", "receipt_key", "presence_key" or
+ "to_device_key". Any other value for `stream_key` will cause this function
+ to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@@ -216,8 +222,15 @@ class ApplicationServicesHandler:
if not self.notify_appservices:
return
- # Ignore any unsupported streams
- if stream_key not in ("typing_key", "receipt_key", "presence_key"):
+ # Notify appservices of updates in ephemeral event streams.
+ # Only the following streams are currently supported.
+ # FIXME: We should use constants for these values.
+ if stream_key not in (
+ "typing_key",
+ "receipt_key",
+ "presence_key",
+ "to_device_key",
+ ):
return
# Assert that new_token is an integer (and not a RoomStreamToken).
@@ -233,6 +246,13 @@ class ApplicationServicesHandler:
# Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int)
+ # Ignore to-device messages if the feature flag is not enabled
+ if (
+ stream_key == "to_device_key"
+ and not self._msc2409_to_device_messages_enabled
+ ):
+ return
+
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
@@ -266,7 +286,7 @@ class ApplicationServicesHandler:
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
if stream_key == "typing_key":
- # Note that we don't persist the token (via set_type_stream_id_for_appservice)
+ # Note that we don't persist the token (via set_appservice_stream_type_pos)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
#
@@ -274,7 +294,7 @@ class ApplicationServicesHandler:
# and, if they apply to this application service, send it off.
events = await self._handle_typing(service, new_token)
if events:
- self.scheduler.submit_ephemeral_events_for_as(service, events)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=events)
continue
# Since we read/update the stream position for this AS/stream
@@ -285,28 +305,37 @@ class ApplicationServicesHandler:
):
if stream_key == "receipt_key":
events = await self._handle_receipts(service, new_token)
- if events:
- self.scheduler.submit_ephemeral_events_for_as(
- service, events
- )
+ self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
- await self.store.set_type_stream_id_for_appservice(
+ await self.store.set_appservice_stream_type_pos(
service, "read_receipt", new_token
)
elif stream_key == "presence_key":
events = await self._handle_presence(service, users, new_token)
- if events:
- self.scheduler.submit_ephemeral_events_for_as(
- service, events
- )
+ self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
- await self.store.set_type_stream_id_for_appservice(
+ await self.store.set_appservice_stream_type_pos(
service, "presence", new_token
)
+ elif stream_key == "to_device_key":
+ # Retrieve a list of to-device message events, as well as the
+ # maximum stream token of the messages we were able to retrieve.
+ to_device_messages = await self._get_to_device_messages(
+ service, new_token, users
+ )
+ self.scheduler.enqueue_for_appservice(
+ service, to_device_messages=to_device_messages
+ )
+
+ # Persist the latest handled stream token for this appservice
+ await self.store.set_appservice_stream_type_pos(
+ service, "to_device", new_token
+ )
+
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@@ -440,6 +469,79 @@ class ApplicationServicesHandler:
return events
+ async def _get_to_device_messages(
+ self,
+ service: ApplicationService,
+ new_token: int,
+ users: Collection[Union[str, UserID]],
+ ) -> List[JsonDict]:
+ """
+ Given an application service, determine which events it should receive
+ from those between the last-recorded to-device message stream token for this
+ appservice and the given stream token.
+
+ Args:
+ service: The application service to check for which events it should receive.
+ new_token: The latest to-device event stream token.
+ users: The users to be notified for the new to-device messages
+ (ie, the recipients of the messages).
+
+ Returns:
+ A list of JSON dictionaries containing data derived from the to-device events
+ that should be sent to the given application service.
+ """
+ # Get the stream token that this application service has processed up until
+ from_key = await self.store.get_type_stream_id_for_appservice(
+ service, "to_device"
+ )
+
+ # Filter out users that this appservice is not interested in
+ users_appservice_is_interested_in: List[str] = []
+ for user in users:
+ # FIXME: We should do this farther up the call stack. We currently repeat
+ # this operation in _handle_presence.
+ if isinstance(user, UserID):
+ user = user.to_string()
+
+ if service.is_interested_in_user(user):
+ users_appservice_is_interested_in.append(user)
+
+ if not users_appservice_is_interested_in:
+ # Return early if the AS was not interested in any of these users
+ return []
+
+ # Retrieve the to-device messages for each user
+ recipient_device_to_messages = await self.store.get_messages_for_user_devices(
+ users_appservice_is_interested_in,
+ from_key,
+ new_token,
+ )
+
+ # According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
+ # to the event JSON so that the application service will know which user/device
+ # combination this messages was intended for.
+ #
+ # So we mangle this dict into a flat list of to-device messages with the relevant
+ # user ID and device ID embedded inside each message dict.
+ message_payload: List[JsonDict] = []
+ for (
+ user_id,
+ device_id,
+ ), messages in recipient_device_to_messages.items():
+ for message_json in messages:
+ # Remove 'message_id' from the to-device message, as it's an internal ID
+ message_json.pop("message_id", None)
+
+ message_payload.append(
+ {
+ "to_user_id": user_id,
+ "to_device_id": device_id,
+ **message_json,
+ }
+ )
+
+ return message_payload
+
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bd1a322563..e32c93e234 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+ [JsonDict, JsonDict],
+ Awaitable[Optional[str]],
+]
class PasswordAuthProvider:
@@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+ self.get_username_for_registration_callbacks: List[
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK
+ ] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
+ get_username_for_registration: Optional[
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
@@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
+ if get_username_for_registration is not None:
+ self.get_username_for_registration_callbacks.append(
+ get_username_for_registration,
+ )
+
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
+
+ async def get_username_for_registration(
+ self,
+ uia_results: JsonDict,
+ params: JsonDict,
+ ) -> Optional[str]:
+ """Defines the username to use when registering the user, using the credentials
+ and parameters provided during the UIA flow.
+
+ Stops at the first callback that returns a string.
+
+ Args:
+ uia_results: The credentials provided during the UIA flow.
+ params: The parameters provided by the registration request.
+
+ Returns:
+ The localpart to use when registering this user, or None if no module
+ returned a localpart.
+ """
+ for callback in self.get_username_for_registration_callbacks:
+ try:
+ res = await callback(uia_results, params)
+
+ if isinstance(res, str):
+ return res
+ elif res is not None:
+ # mypy complains that this line is unreachable because it assumes the
+ # data returned by the module fits the expected type. We just want
+ # to make sure this is the case.
+ logger.warning( # type: ignore[unreachable]
+ "Ignoring non-string value returned by"
+ " get_username_for_registration callback %s: %s",
+ callback,
+ res,
+ )
+ except Exception as e:
+ logger.error(
+ "Module raised an exception in get_username_for_registration: %s",
+ e,
+ )
+ raise SynapseError(code=500, msg="Internal Server Error")
+
+ return None
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index bee62cf360..7a13d76a68 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -157,6 +157,9 @@ class DeactivateAccountHandler:
# Mark the user as deactivated.
await self.store.set_user_deactivated_status(user_id, True)
+ # Remove account data (including ignored users and push rules).
+ await self.store.purge_account_data_for_user(user_id)
+
return identity_server_supports_unbinding
async def _reject_pending_invites_for_user(self, user_id: str) -> None:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6b5a6ded8b..36e3ad2ba9 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -31,6 +31,8 @@ from synapse.types import (
create_requester,
get_domain_from_id,
)
+from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import parse_and_validate_mxc_uri
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -64,6 +66,11 @@ class ProfileHandler:
self.user_directory_handler = hs.get_user_directory_handler()
self.request_ratelimiter = hs.get_request_ratelimiter()
+ self.max_avatar_size = hs.config.server.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
+
+ self.server_name = hs.config.server.server_name
+
if hs.config.worker.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
@@ -286,6 +293,9 @@ class ProfileHandler:
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ if not await self.check_avatar_size_and_mime_type(new_avatar_url):
+ raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN)
+
avatar_url_to_set: Optional[str] = new_avatar_url
if new_avatar_url == "":
avatar_url_to_set = None
@@ -307,6 +317,63 @@ class ProfileHandler:
await self._update_join_states(requester, target_user)
+ @cached()
+ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
+ """Check that the size and content type of the avatar at the given MXC URI are
+ within the configured limits.
+
+ Args:
+ mxc: The MXC URI at which the avatar can be found.
+
+ Returns:
+ A boolean indicating whether the file can be allowed to be set as an avatar.
+ """
+ if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
+ return True
+
+ server_name, _, media_id = parse_and_validate_mxc_uri(mxc)
+
+ if server_name == self.server_name:
+ media_info = await self.store.get_local_media(media_id)
+ else:
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
+
+ if media_info is None:
+ # Both configuration options need to access the file's metadata, and
+ # retrieving remote avatars just for this becomes a bit of a faff, especially
+ # if e.g. the file is too big. It's also generally safe to assume most files
+ # used as avatar are uploaded locally, or if the upload didn't happen as part
+ # of a PUT request on /avatar_url that the file was at least previewed by the
+ # user locally (and therefore downloaded to the remote media cache).
+ logger.warning("Forbidding avatar change to %s: avatar not on server", mxc)
+ return False
+
+ if self.max_avatar_size:
+ # Ensure avatar does not exceed max allowed avatar size
+ if media_info["media_length"] > self.max_avatar_size:
+ logger.warning(
+ "Forbidding avatar change to %s: %d bytes is above the allowed size "
+ "limit",
+ mxc,
+ media_info["media_length"],
+ )
+ return False
+
+ if self.allowed_avatar_mimetypes:
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ logger.warning(
+ "Forbidding avatar change to %s: mimetype %s not allowed",
+ mxc,
+ media_info["media_type"],
+ )
+ return False
+
+ return True
+
async def on_profile_query(self, args: JsonDict) -> JsonDict:
"""Handles federation profile query requests."""
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f08a516a75..a719d5eef3 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -132,6 +132,7 @@ class RegistrationHandler:
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
+ inhibit_user_in_use_error: bool = False,
) -> None:
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
@@ -171,21 +172,22 @@ class RegistrationHandler:
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
- if not guest_access_token:
+ if not inhibit_user_in_use_error and not guest_access_token:
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
- user_data = await self.auth.get_user_by_access_token(guest_access_token)
- if (
- not user_data.is_guest
- or UserID.from_string(user_data.user_id).localpart != localpart
- ):
- raise AuthError(
- 403,
- "Cannot register taken user ID without valid guest "
- "credentials for that user.",
- errcode=Codes.FORBIDDEN,
- )
+ if guest_access_token:
+ user_data = await self.auth.get_user_by_access_token(guest_access_token)
+ if (
+ not user_data.is_guest
+ or UserID.from_string(user_data.user_id).localpart != localpart
+ ):
+ raise AuthError(
+ 403,
+ "Cannot register taken user ID without valid guest "
+ "credentials for that user.",
+ errcode=Codes.FORBIDDEN,
+ )
if guest_access_token is None:
try:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f963078e59..1420d67729 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -30,6 +30,7 @@ from typing import (
Tuple,
)
+import attr
from typing_extensions import TypedDict
from synapse.api.constants import (
@@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@@ -90,6 +92,17 @@ id_server_scheme = "https://"
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventContext:
+ events_before: List[EventBase]
+ event: EventBase
+ events_after: List[EventBase]
+ state: List[EventBase]
+ aggregations: Dict[str, BundledAggregations]
+ start: str
+ end: str
+
+
class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ class RoomContextHandler:
limit: int,
event_filter: Optional[Filter],
use_admin_priviledge: bool = False,
- ) -> Optional[JsonDict]:
+ ) -> Optional[EventContext]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -1167,38 +1180,28 @@ class RoomContextHandler:
results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
+ events_before = results.events_before
+ events_after = results.events_after
if event_filter:
- results["events_before"] = await event_filter.filter(
- results["events_before"]
- )
- results["events_after"] = await event_filter.filter(results["events_after"])
+ events_before = await event_filter.filter(events_before)
+ events_after = await event_filter.filter(events_after)
- results["events_before"] = await filter_evts(results["events_before"])
- results["events_after"] = await filter_evts(results["events_after"])
+ events_before = await filter_evts(events_before)
+ events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
- results["event"] = filtered[0]
+ event = filtered[0]
# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
- [results["event"]], user.to_string()
+ itertools.chain(events_before, (event,), events_after),
+ user.to_string(),
)
- aggregations.update(
- await self.store.get_bundled_aggregations(
- results["events_before"], user.to_string()
- )
- )
- aggregations.update(
- await self.store.get_bundled_aggregations(
- results["events_after"], user.to_string()
- )
- )
- results["aggregations"] = aggregations
- if results["events_after"]:
- last_event_id = results["events_after"][-1].event_id
+ if events_after:
+ last_event_id = events_after[-1].event_id
else:
last_event_id = event_id
@@ -1206,9 +1209,9 @@ class RoomContextHandler:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
- results["events_before"],
- (results["event"],),
- results["events_after"],
+ events_before,
+ (event,),
+ events_after,
)
)
else:
@@ -1226,21 +1229,23 @@ class RoomContextHandler:
if event_filter:
state_events = await event_filter.filter(state_events)
- results["state"] = await filter_evts(state_events)
-
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
token = StreamToken.START
- results["start"] = await token.copy_and_replace(
- "room_key", results["start"]
- ).to_string(self.store)
-
- results["end"] = await token.copy_and_replace(
- "room_key", results["end"]
- ).to_string(self.store)
-
- return results
+ return EventContext(
+ events_before=events_before,
+ event=event,
+ events_after=events_after,
+ state=await filter_evts(state_events),
+ aggregations=aggregations,
+ start=await token.copy_and_replace("room_key", results.start).to_string(
+ self.store
+ ),
+ end=await token.copy_and_replace("room_key", results.end).to_string(
+ self.store
+ ),
+ )
class TimestampLookupHandler:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6aa910dd10..efe6b4c9aa 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -116,6 +116,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
)
+ self._third_party_invite_limiter = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second,
+ burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count,
+ )
+
self.request_ratelimiter = hs.get_request_ratelimiter()
@abc.abstractmethod
@@ -590,6 +597,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
errcode=Codes.BAD_JSON,
)
+ if "avatar_url" in content:
+ if not await self.profile_handler.check_avatar_size_and_mime_type(
+ content["avatar_url"],
+ ):
+ raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN)
+
# The event content should *not* include the authorising user as
# it won't be properly signed. Strip it out since it might come
# back from a client updating a display name / avatar.
@@ -1289,7 +1302,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- await self.request_ratelimiter.ratelimit(requester)
+ await self._third_party_invite_limiter.ratelimit(requester)
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0b153a6822..02bb5ae72f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -361,36 +361,37 @@ class SearchHandler:
logger.info(
"Context for search returned %d and %d events",
- len(res["events_before"]),
- len(res["events_after"]),
+ len(res.events_before),
+ len(res.events_after),
)
- res["events_before"] = await filter_events_for_client(
- self.storage, user.to_string(), res["events_before"]
+ events_before = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_before
)
- res["events_after"] = await filter_events_for_client(
- self.storage, user.to_string(), res["events_after"]
+ events_after = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_after
)
- res["start"] = await now_token.copy_and_replace(
- "room_key", res["start"]
- ).to_string(self.store)
-
- res["end"] = await now_token.copy_and_replace(
- "room_key", res["end"]
- ).to_string(self.store)
+ context = {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": await now_token.copy_and_replace(
+ "room_key", res.start
+ ).to_string(self.store),
+ "end": await now_token.copy_and_replace(
+ "room_key", res.end
+ ).to_string(self.store),
+ }
if include_profile:
senders = {
ev.sender
- for ev in itertools.chain(
- res["events_before"], [event], res["events_after"]
- )
+ for ev in itertools.chain(events_before, [event], events_after)
}
- if res["events_after"]:
- last_event_id = res["events_after"][-1].event_id
+ if events_after:
+ last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id
@@ -402,7 +403,7 @@ class SearchHandler:
last_event_id, state_filter
)
- res["profile_info"] = {
+ context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
@@ -411,7 +412,7 @@ class SearchHandler:
if s.type == EventTypes.Member and s.state_key in senders
}
- contexts[event.event_id] = res
+ contexts[event.event_id] = context
else:
contexts = {}
@@ -421,10 +422,10 @@ class SearchHandler:
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
- context["events_before"], time_now
+ context["events_before"], time_now # type: ignore[arg-type]
)
context["events_after"] = self._event_serializer.serialize_events(
- context["events_after"], time_now
+ context["events_after"], time_now # type: ignore[arg-type]
)
state_results = {}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ffc6b748e8..aa9a76f8a9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -100,7 +101,7 @@ class TimelineBatch:
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
- bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
+ bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -1347,8 +1348,8 @@ class SyncHandler:
if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key)
- if since_stream_id != int(now_token.to_device_key):
- messages, stream_id = await self.store.get_new_messages_for_device(
+ if device_id is not None and since_stream_id != int(now_token.to_device_key):
+ messages, stream_id = await self.store.get_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
@@ -1619,7 +1620,7 @@ class SyncHandler:
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
)
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca33b45cb2..743a7ffcb1 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -731,15 +731,24 @@ class SimpleHttpClient:
# straight back in again
try:
- length = await make_deferred_yieldable(
- read_body_with_max_size(response, output_stream, max_size)
- )
+ d = read_body_with_max_size(response, output_stream, max_size)
+
+ # Ensure that the body is not read forever.
+ d = timeout_deferred(d, 30, self.hs.get_reactor())
+
+ length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
raise SynapseError(
HTTPStatus.BAD_GATEWAY,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
+ except defer.TimeoutError:
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ "Requested file took too long to download",
+ Codes.TOO_LARGE,
+ )
except Exception as e:
raise SynapseError(
HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c180a1d323..40f6c04894 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -407,7 +407,10 @@ class SynapseRequest(Request):
user_agent = get_request_user_agent(self, "-")
- code = str(self.code)
+ # int(self.code) looks redundant, because self.code is already an int.
+ # But self.code might be an HTTPStatus (which inherits from int)---which has
+ # a different string representation. So ensure we really have an integer.
+ code = str(int(self.code))
if not self.finished:
# we didn't send the full response before we gave up (presumably because
# the connection dropped)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index b240d2d21d..3ebed5c161 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -443,10 +443,14 @@ def start_active_span(
start_time=None,
ignore_active_span=False,
finish_on_close=True,
+ *,
+ tracer=None,
):
- """Starts an active opentracing span. Note, the scope doesn't become active
- until it has been entered, however, the span starts from the time this
- message is called.
+ """Starts an active opentracing span.
+
+ Records the start time for the span, and sets it as the "active span" in the
+ scope manager.
+
Args:
See opentracing.tracer
Returns:
@@ -456,7 +460,11 @@ def start_active_span(
if opentracing is None:
return noop_context_manager() # type: ignore[unreachable]
- return opentracing.tracer.start_active_span(
+ if tracer is None:
+ # use the global tracer by default
+ tracer = opentracing.tracer
+
+ return tracer.start_active_span(
operation_name,
child_of=child_of,
references=references,
@@ -468,21 +476,42 @@ def start_active_span(
def start_active_span_follows_from(
- operation_name: str, contexts: Collection, inherit_force_tracing=False
+ operation_name: str,
+ contexts: Collection,
+ child_of=None,
+ start_time: Optional[float] = None,
+ *,
+ inherit_force_tracing=False,
+ tracer=None,
):
"""Starts an active opentracing span, with additional references to previous spans
Args:
operation_name: name of the operation represented by the new span
contexts: the previous spans to inherit from
+
+ child_of: optionally override the parent span. If unset, the currently active
+ span will be the parent. (If there is no currently active span, the first
+ span in `contexts` will be the parent.)
+
+ start_time: optional override for the start time of the created span. Seconds
+ since the epoch.
+
inherit_force_tracing: if set, and any of the previous contexts have had tracing
forced, the new span will also have tracing forced.
+ tracer: override the opentracing tracer. By default the global tracer is used.
"""
if opentracing is None:
return noop_context_manager() # type: ignore[unreachable]
references = [opentracing.follows_from(context) for context in contexts]
- scope = start_active_span(operation_name, references=references)
+ scope = start_active_span(
+ operation_name,
+ child_of=child_of,
+ references=references,
+ start_time=start_time,
+ tracer=tracer,
+ )
if inherit_force_tracing and any(
is_context_forced_tracing(ctx) for ctx in contexts
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index db8ca2c049..d57e7c5324 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -28,8 +28,9 @@ class LogContextScopeManager(ScopeManager):
The LogContextScopeManager tracks the active scope in opentracing
by using the log contexts which are native to synapse. This is so
that the basic opentracing api can be used across twisted defereds.
- (I would love to break logcontexts and this into an OS package. but
- let's wait for twisted's contexts to be released.)
+
+ It would be nice just to use opentracing's ContextVarsScopeManager,
+ but currently that doesn't work due to https://twistedmatrix.com/trac/ticket/10301.
"""
def __init__(self, config):
@@ -65,29 +66,45 @@ class LogContextScopeManager(ScopeManager):
Scope.close() on the returned instance.
"""
- enter_logcontext = False
ctx = current_context()
if not ctx:
- # We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span) # type: ignore[arg-type]
- elif ctx.scope is not None:
- # We want the logging scope to look exactly the same so we give it
- # a blank suffix
+
+ if ctx.scope is not None:
+ # start a new logging context as a child of the existing one.
+ # Doing so -- rather than updating the existing logcontext -- means that
+ # creating several concurrent spans under the same logcontext works
+ # correctly.
ctx = nested_logging_context("")
enter_logcontext = True
+ else:
+ # if there is no span currently associated with the current logcontext, we
+ # just store the scope in it.
+ #
+ # This feels a bit dubious, but it does hack around a problem where a
+ # span outlasts its parent logcontext (which would otherwise lead to
+ # "Re-starting finished log context" errors).
+ enter_logcontext = False
scope = _LogContextScope(self, span, ctx, enter_logcontext, finish_on_close)
ctx.scope = scope
+ if enter_logcontext:
+ ctx.__enter__()
+
return scope
class _LogContextScope(Scope):
"""
- A custom opentracing scope. The only significant difference is that it will
- close the log context it's related to if the logcontext was created specifically
- for this scope.
+ A custom opentracing scope, associated with a LogContext
+
+ * filters out _DefGen_Return exceptions which arise from calling
+ `defer.returnValue` in Twisted code
+
+ * When the scope is closed, the logcontext's active scope is reset to None.
+ and - if enter_logcontext was set - the logcontext is finished too.
"""
def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
@@ -101,8 +118,7 @@ class _LogContextScope(Scope):
logcontext (LogContext):
the logcontext to which this scope is attached.
enter_logcontext (Boolean):
- if True the logcontext will be entered and exited when the scope
- is entered and exited respectively
+ if True the logcontext will be exited when the scope is finished
finish_on_close (Boolean):
if True finish the span when the scope is closed
"""
@@ -111,26 +127,28 @@ class _LogContextScope(Scope):
self._finish_on_close = finish_on_close
self._enter_logcontext = enter_logcontext
- def __enter__(self):
- if self._enter_logcontext:
- self.logcontext.__enter__()
+ def __exit__(self, exc_type, value, traceback):
+ if exc_type == twisted.internet.defer._DefGen_Return:
+ # filter out defer.returnValue() calls
+ exc_type = value = traceback = None
+ super().__exit__(exc_type, value, traceback)
- return self
-
- def __exit__(self, type, value, traceback):
- if type == twisted.internet.defer._DefGen_Return:
- super().__exit__(None, None, None)
- else:
- super().__exit__(type, value, traceback)
- if self._enter_logcontext:
- self.logcontext.__exit__(type, value, traceback)
- else: # the logcontext existed before the creation of the scope
- self.logcontext.scope = None
+ def __str__(self):
+ return f"Scope<{self.span}>"
def close(self):
- if self.manager.active is not self:
- logger.error("Tried to close a non-active scope!")
- return
+ active_scope = self.manager.active
+ if active_scope is not self:
+ logger.error(
+ "Closing scope %s which is not the currently-active one %s",
+ self,
+ active_scope,
+ )
if self._finish_on_close:
self.span.finish()
+
+ self.logcontext.scope = None
+
+ if self._enter_logcontext:
+ self.logcontext.__exit__(None, None, None)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 9e6c1b2f3b..cca084c18c 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -30,6 +30,7 @@ from typing import (
Type,
TypeVar,
Union,
+ cast,
)
import attr
@@ -60,7 +61,7 @@ all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
-class RegistryProxy:
+class _RegistryProxy:
@staticmethod
def collect() -> Iterable[Metric]:
for metric in REGISTRY.collect():
@@ -68,6 +69,13 @@ class RegistryProxy:
yield metric
+# A little bit nasty, but collect() above is static so a Protocol doesn't work.
+# _RegistryProxy matches the signature of a CollectorRegistry instance enough
+# for it to be usable in the contexts in which we use it.
+# TODO Do something nicer about this.
+RegistryProxy = cast(CollectorRegistry, _RegistryProxy)
+
+
@attr.s(slots=True, hash=True, auto_attribs=True)
class LaterGauge:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 662e60bc33..29fbc73c97 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -71,6 +71,7 @@ from synapse.handlers.account_validity import (
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
@@ -177,6 +178,7 @@ class ModuleApi:
self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
+ self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
@@ -310,6 +312,9 @@ class ModuleApi:
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
+ get_username_for_registration: Optional[
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
"""Registers callbacks for password auth provider capabilities.
@@ -319,6 +324,7 @@ class ModuleApi:
check_3pid_auth=check_3pid_auth,
on_logged_out=on_logged_out,
auth_checkers=auth_checkers,
+ get_username_for_registration=get_username_for_registration,
)
def register_background_update_controller_callbacks(
@@ -395,6 +401,32 @@ class ModuleApi:
"""
return self._hs.config.email.email_app_name
+ @property
+ def server_name(self) -> str:
+ """The server name for the local homeserver.
+
+ Added in Synapse v1.53.0.
+ """
+ return self._server_name
+
+ @property
+ def worker_name(self) -> Optional[str]:
+ """The name of the worker this specific instance is running as per the
+ "worker_name" configuration setting, or None if it's the main process.
+
+ Added in Synapse v1.53.0.
+ """
+ return self._hs.config.worker.worker_name
+
+ @property
+ def worker_app(self) -> Optional[str]:
+ """The name of the worker app this specific instance is running as per the
+ "worker_app" configuration setting, or None if it's the main process.
+
+ Added in Synapse v1.53.0.
+ """
+ return self._hs.config.worker.worker_app
+
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get user info by user_id
@@ -1202,6 +1234,22 @@ class ModuleApi:
"""
return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+ async def check_username(self, username: str) -> None:
+ """Checks if the provided username uses the grammar defined in the Matrix
+ specification, and is already being used by an existing user.
+
+ Added in Synapse v1.52.0.
+
+ Args:
+ username: The username to check. This is the local part of the user's full
+ Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com".
+
+ Raises:
+ SynapseError with the errcode "M_USER_IN_USE" if the username is already in
+ use.
+ """
+ await self._registration_handler.check_username(username)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 632b2245ef..5988c67d90 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -461,7 +461,9 @@ class Notifier:
users,
)
except Exception:
- logger.exception("Error notifying application services of event")
+ logger.exception(
+ "Error notifying application services of ephemeral events"
+ )
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6211506990..910b05c0da 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -20,15 +20,11 @@ from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
-def list_with_base_rules(
- rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
-) -> List[Dict[str, Any]]:
+def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules
Args:
rawrules: The rules the user has modified or set.
- use_new_defaults: Whether to use the new experimental default rules when
- appending or prepending default rules.
Returns:
A new list with the rules set by the user combined with the defaults.
@@ -48,9 +44,7 @@ def list_with_base_rules(
ruleslist.extend(
make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- use_new_defaults,
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
)
)
@@ -61,7 +55,6 @@ def list_with_base_rules(
make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
- use_new_defaults,
)
)
current_prio_class -= 1
@@ -70,7 +63,6 @@ def list_with_base_rules(
make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
- use_new_defaults,
)
)
@@ -79,18 +71,14 @@ def list_with_base_rules(
while current_prio_class > 0:
ruleslist.extend(
make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- use_new_defaults,
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
)
)
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(
make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- use_new_defaults,
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
)
)
@@ -98,24 +86,14 @@ def list_with_base_rules(
def make_base_append_rules(
- kind: str,
- modified_base_rules: Dict[str, Dict[str, Any]],
- use_new_defaults: bool = False,
+ kind: str, modified_base_rules: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
- rules = (
- NEW_APPEND_OVERRIDE_RULES
- if use_new_defaults
- else BASE_APPEND_OVERRIDE_RULES
- )
+ rules = BASE_APPEND_OVERRIDE_RULES
elif kind == "underride":
- rules = (
- NEW_APPEND_UNDERRIDE_RULES
- if use_new_defaults
- else BASE_APPEND_UNDERRIDE_RULES
- )
+ rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == "content":
rules = BASE_APPEND_CONTENT_RULES
@@ -134,7 +112,6 @@ def make_base_append_rules(
def make_base_prepend_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
- use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = []
@@ -301,135 +278,6 @@ BASE_APPEND_OVERRIDE_RULES = [
]
-NEW_APPEND_OVERRIDE_RULES = [
- {
- "rule_id": "global/override/.m.rule.encrypted",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.encrypted",
- "_id": "_encrypted",
- }
- ],
- "actions": ["notify"],
- },
- {
- "rule_id": "global/override/.m.rule.suppress_notices",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.message",
- "_id": "_suppress_notices_type",
- },
- {
- "kind": "event_match",
- "key": "content.msgtype",
- "pattern": "m.notice",
- "_id": "_suppress_notices",
- },
- ],
- "actions": [],
- },
- {
- "rule_id": "global/underride/.m.rule.suppress_edits",
- "conditions": [
- {
- "kind": "event_match",
- "key": "m.relates_to.m.rel_type",
- "pattern": "m.replace",
- "_id": "_suppress_edits",
- }
- ],
- "actions": [],
- },
- {
- "rule_id": "global/override/.m.rule.invite_for_me",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.member",
- "_id": "_member",
- },
- {
- "kind": "event_match",
- "key": "content.membership",
- "pattern": "invite",
- "_id": "_invite_member",
- },
- {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
- ],
- "actions": ["notify", {"set_tweak": "sound", "value": "default"}],
- },
- {
- "rule_id": "global/override/.m.rule.contains_display_name",
- "conditions": [{"kind": "contains_display_name"}],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
- ],
- },
- {
- "rule_id": "global/override/.m.rule.tombstone",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.tombstone",
- "_id": "_tombstone",
- },
- {
- "kind": "event_match",
- "key": "state_key",
- "pattern": "",
- "_id": "_tombstone_statekey",
- },
- ],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
- ],
- },
- {
- "rule_id": "global/override/.m.rule.roomnotif",
- "conditions": [
- {
- "kind": "event_match",
- "key": "content.body",
- "pattern": "@room",
- "_id": "_roomnotif_content",
- },
- {
- "kind": "sender_notification_permission",
- "key": "room",
- "_id": "_roomnotif_pl",
- },
- ],
- "actions": [
- "notify",
- {"set_tweak": "highlight"},
- {"set_tweak": "sound", "value": "default"},
- ],
- },
- {
- "rule_id": "global/override/.m.rule.call",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.call.invite",
- "_id": "_call",
- }
- ],
- "actions": ["notify", {"set_tweak": "sound", "value": "ring"}],
- },
-]
-
-
BASE_APPEND_UNDERRIDE_RULES = [
{
"rule_id": "global/underride/.m.rule.call",
@@ -538,36 +386,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
]
-NEW_APPEND_UNDERRIDE_RULES = [
- {
- "rule_id": "global/underride/.m.rule.room_one_to_one",
- "conditions": [
- {"kind": "room_member_count", "is": "2", "_id": "member_count"},
- {
- "kind": "event_match",
- "key": "content.body",
- "pattern": "*",
- "_id": "body",
- },
- ],
- "actions": ["notify", {"set_tweak": "sound", "value": "default"}],
- },
- {
- "rule_id": "global/underride/.m.rule.message",
- "conditions": [
- {
- "kind": "event_match",
- "key": "content.body",
- "pattern": "*",
- "_id": "body",
- },
- ],
- "actions": ["notify"],
- "enabled": False,
- },
-]
-
-
BASE_RULE_IDS = set()
for r in BASE_APPEND_CONTENT_RULES:
@@ -589,26 +407,3 @@ for r in BASE_APPEND_UNDERRIDE_RULES:
r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
r["default"] = True
BASE_RULE_IDS.add(r["rule_id"])
-
-
-NEW_RULE_IDS = set()
-
-for r in BASE_APPEND_CONTENT_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["content"]
- r["default"] = True
- NEW_RULE_IDS.add(r["rule_id"])
-
-for r in BASE_PREPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- NEW_RULE_IDS.add(r["rule_id"])
-
-for r in NEW_APPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- NEW_RULE_IDS.add(r["rule_id"])
-
-for r in NEW_APPEND_UNDERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
- r["default"] = True
- NEW_RULE_IDS.add(r["rule_id"])
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index dadfc57413..3df8452eec 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -455,7 +455,7 @@ class Mailer:
}
the_events = await filter_events_for_client(
- self.storage, user_id, results["events_before"]
+ self.storage, user_id, results.events_before
)
the_events.append(notif_event)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index d844fbb3b3..22b4606ae0 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -70,7 +70,7 @@ REQUIREMENTS = [
"pyasn1>=0.1.9",
"pyasn1-modules>=0.0.7",
"bcrypt>=3.1.0",
- "pillow>=4.3.0",
+ "pillow>=5.4.0",
"sortedcontainers>=1.4.4",
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
@@ -107,7 +107,7 @@ CONDITIONAL_REQUIREMENTS = {
# `systemd.journal.JournalHandler`, as is documented in
# `contrib/systemd/log_config.yaml`.
"systemd": ["systemd-python>=231"],
- "url_preview": ["lxml>=3.5.0"],
+ "url_preview": ["lxml>=4.2.0"],
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 0f08372694..a72dad7464 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -52,8 +52,8 @@ class SlavedEventStore(
EventPushActionsWorkerStore,
StreamWorkerStore,
StateGroupWorkerStore,
- EventsWorkerStore,
SignatureWorkerStore,
+ EventsWorkerStore,
UserErasureWorkerStore,
RelationsWorkerStore,
BaseSlavedStore,
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 465e06772b..9be9e33c8e 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -41,7 +41,9 @@ from synapse.rest.admin.event_reports import (
EventReportsRestServlet,
)
from synapse.rest.admin.federation import (
- DestinationsRestServlet,
+ DestinationMembershipRestServlet,
+ DestinationResetConnectionRestServlet,
+ DestinationRestServlet,
ListDestinationsRestServlet,
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
@@ -267,7 +269,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ListRegistrationTokensRestServlet(hs).register(http_server)
NewRegistrationTokenRestServlet(hs).register(http_server)
RegistrationTokenRestServlet(hs).register(http_server)
- DestinationsRestServlet(hs).register(http_server)
+ DestinationMembershipRestServlet(hs).register(http_server)
+ DestinationResetConnectionRestServlet(hs).register(http_server)
+ DestinationRestServlet(hs).register(http_server)
ListDestinationsRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 8cd3fa189e..d162e0081e 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -16,6 +16,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
@@ -90,7 +91,7 @@ class ListDestinationsRestServlet(RestServlet):
return HTTPStatus.OK, response
-class DestinationsRestServlet(RestServlet):
+class DestinationRestServlet(RestServlet):
"""Get details of a destination.
This needs user to have administrator access in Synapse.
@@ -145,3 +146,100 @@ class DestinationsRestServlet(RestServlet):
}
return HTTPStatus.OK, response
+
+
+class DestinationMembershipRestServlet(RestServlet):
+ """Get list of rooms of a destination.
+ This needs user to have administrator access in Synapse.
+
+ GET /_synapse/admin/v1/federation/destinations/<destination>/rooms?from=0&limit=10
+
+ returns:
+ 200 OK with a list of rooms if success otherwise an error.
+
+ The parameters `from` and `limit` are required only for pagination.
+ By default, a `limit` of 100 is used.
+ """
+
+ PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)/rooms$")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastore()
+
+ async def on_GET(
+ self, request: SynapseRequest, destination: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ if not await self._store.is_destination_known(destination):
+ raise NotFoundError("Unknown destination")
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
+ rooms, total = await self._store.get_destination_rooms_paginate(
+ destination, start, limit, direction
+ )
+ response = {"rooms": rooms, "total": total}
+ if (start + limit) < total:
+ response["next_token"] = str(start + len(rooms))
+
+ return HTTPStatus.OK, response
+
+
+class DestinationResetConnectionRestServlet(RestServlet):
+ """Reset destinations' connection timeouts and wake it up.
+ This needs user to have administrator access in Synapse.
+
+ POST /_synapse/admin/v1/federation/destinations/<destination>/reset_connection
+ {}
+
+ returns:
+ 200 OK otherwise an error.
+ """
+
+ PATTERNS = admin_patterns(
+ "/federation/destinations/(?P<destination>[^/]+)/reset_connection$"
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastore()
+ self._authenticator = Authenticator(hs)
+
+ async def on_POST(
+ self, request: SynapseRequest, destination: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ if not await self._store.is_destination_known(destination):
+ raise NotFoundError("Unknown destination")
+
+ retry_timings = await self._store.get_destination_retry_timings(destination)
+ if not (retry_timings and retry_timings.retry_last_ts):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "The retry timing does not need to be reset for this destination.",
+ )
+
+ # reset timings and wake up
+ await self._authenticator.reset_retry_timings(destination)
+
+ return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index efe25fe7eb..5b706efbcf 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = await self.room_context_handler.get_event_context(
+ event_context = await self.room_context_handler.get_event_context(
requester,
room_id,
event_id,
@@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet):
use_admin_priviledge=True,
)
- if not results:
+ if not event_context:
raise SynapseError(
HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
)
time_now = self.clock.time_msec()
- aggregations = results.pop("aggregations", None)
- results["events_before"] = self._event_serializer.serialize_events(
- results["events_before"], time_now, bundle_aggregations=aggregations
- )
- results["event"] = self._event_serializer.serialize_event(
- results["event"], time_now, bundle_aggregations=aggregations
- )
- results["events_after"] = self._event_serializer.serialize_events(
- results["events_after"], time_now, bundle_aggregations=aggregations
- )
- results["state"] = self._event_serializer.serialize_events(
- results["state"], time_now
- )
+ results = {
+ "events_before": self._event_serializer.serialize_events(
+ event_context.events_before,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "event": self._event_serializer.serialize_event(
+ event_context.event,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "events_after": self._event_serializer.serialize_events(
+ event_context.events_after,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "state": self._event_serializer.serialize_events(
+ event_context.state, time_now
+ ),
+ "start": event_context.start,
+ "end": event_context.end,
+ }
return HTTPStatus.OK, results
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index d1badbdf3b..58b8adbd32 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -66,7 +66,7 @@ class AccountDataServlet(RestServlet):
raise AuthError(403, "Cannot get account data for other users.")
event = await self.store.get_global_account_data_by_type_for_user(
- account_data_type, user_id
+ user_id, account_data_type
)
if event is None:
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 6f796d5e50..8fe75bd750 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -29,7 +29,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
+from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
@@ -61,10 +61,6 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
- self._users_new_default_push_rules = (
- hs.config.server.users_new_default_push_rules
- )
-
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -217,12 +213,7 @@ class PushRuleRestServlet(RestServlet):
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
- if user_id in self._users_new_default_push_rules:
- rule_ids = NEW_RULE_IDS
- else:
- rule_ids = BASE_RULE_IDS
-
- if namespaced_rule_id not in rule_ids:
+ if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 8b56c76aed..e3492f9f93 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -339,12 +339,19 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
+ self.inhibit_user_in_use_error = (
+ hs.config.registration.inhibit_user_in_use_error
+ )
+
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
+ if self.inhibit_user_in_use_error:
+ return 200, {"available": True}
+
ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@@ -418,10 +425,14 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self.password_auth_provider = hs.get_password_auth_provider()
self._registration_enabled = self.hs.config.registration.enable_registration
self._refresh_tokens_enabled = (
hs.config.registration.refreshable_access_token_lifetime is not None
)
+ self._inhibit_user_in_use_error = (
+ hs.config.registration.inhibit_user_in_use_error
+ )
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -564,6 +575,7 @@ class RegisterRestServlet(RestServlet):
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
+ inhibit_user_in_use_error=self._inhibit_user_in_use_error,
)
# Check if the user-interactive authentication flows are complete, if
@@ -627,7 +639,16 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = params.get("username", None)
+ desired_username = await (
+ self.password_auth_provider.get_username_for_registration(
+ auth_result,
+ params,
+ )
+ )
+
+ if desired_username is None:
+ desired_username = params.get("username", None)
+
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 90bb9142a0..90355e44b2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = await self.room_context_handler.get_event_context(
+ event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
)
- if not results:
+ if not event_context:
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- aggregations = results.pop("aggregations", None)
- results["events_before"] = self._event_serializer.serialize_events(
- results["events_before"], time_now, bundle_aggregations=aggregations
- )
- results["event"] = self._event_serializer.serialize_event(
- results["event"], time_now, bundle_aggregations=aggregations
- )
- results["events_after"] = self._event_serializer.serialize_events(
- results["events_after"], time_now, bundle_aggregations=aggregations
- )
- results["state"] = self._event_serializer.serialize_events(
- results["state"], time_now
- )
+ results = {
+ "events_before": self._event_serializer.serialize_events(
+ event_context.events_before,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "event": self._event_serializer.serialize_event(
+ event_context.event,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "events_after": self._event_serializer.serialize_events(
+ event_context.events_after,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "state": self._event_serializer.serialize_events(
+ event_context.state, time_now
+ ),
+ "start": event_context.start,
+ "end": event_context.end,
+ }
return 200, results
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d20ae1421e..f9615da525 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
def serialize(
events: Iterable[EventBase],
- aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+ aggregations: Optional[Dict[str, BundledAggregations]] = None,
) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 30b067dd42..872a9e72e8 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -321,14 +321,33 @@ def _iterate_over_text(
def rebase_url(url: str, base: str) -> str:
- base_parts = list(urlparse.urlparse(base))
+ """
+ Resolves a potentially relative `url` against an absolute `base` URL.
+
+ For example:
+
+ >>> rebase_url("subpage", "https://example.com/foo/")
+ 'https://example.com/foo/subpage'
+ >>> rebase_url("sibling", "https://example.com/foo")
+ 'https://example.com/sibling'
+ >>> rebase_url("/bar", "https://example.com/foo/")
+ 'https://example.com/bar'
+ >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
+ 'https://alice.com/a'
+ """
+ base_parts = urlparse.urlparse(base)
+ # Convert the parsed URL to a list for (potential) modification.
url_parts = list(urlparse.urlparse(url))
- if not url_parts[0]: # fix up schema
- url_parts[0] = base_parts[0] or "http"
- if not url_parts[1]: # fix up hostname
- url_parts[1] = base_parts[1]
+ # Add a scheme, if one does not exist.
+ if not url_parts[0]:
+ url_parts[0] = base_parts.scheme or "http"
+ # Fix up the hostname, if this is not a data URL.
+ if url_parts[0] != "data" and not url_parts[1]:
+ url_parts[1] = base_parts.netloc
+ # If the path does not start with a /, nest it under the base path's last
+ # directory.
if not url_parts[2].startswith("/"):
- url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+ url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
return urlparse.urlunparse(url_parts)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e8881bc870..efd84ced8f 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -21,8 +21,9 @@ import re
import shutil
import sys
import traceback
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
from urllib import parse as urlparse
+from urllib.request import urlopen
import attr
@@ -71,6 +72,17 @@ IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY
@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DownloadResult:
+ length: int
+ uri: str
+ response_code: int
+ media_type: str
+ download_name: Optional[str]
+ expires: int
+ etag: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class MediaInfo:
"""
Information parsed from downloading media being previewed.
@@ -256,7 +268,7 @@ class PreviewUrlResource(DirectServeJsonResource):
if oembed_url:
url_to_download = oembed_url
- media_info = await self._download_url(url_to_download, user)
+ media_info = await self._handle_url(url_to_download, user)
logger.debug("got media_info of '%s'", media_info)
@@ -297,7 +309,9 @@ class PreviewUrlResource(DirectServeJsonResource):
oembed_url = self._oembed.autodiscover_from_html(tree)
og_from_oembed: JsonDict = {}
if oembed_url:
- oembed_info = await self._download_url(oembed_url, user)
+ oembed_info = await self._handle_url(
+ oembed_url, user, allow_data_urls=True
+ )
(
og_from_oembed,
author_name,
@@ -367,7 +381,135 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
- async def _download_url(self, url: str, user: UserID) -> MediaInfo:
+ async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult:
+ """
+ Fetches a remote URL and parses the headers.
+
+ Args:
+ url: The URL to fetch.
+ output_stream: The stream to write the content to.
+
+ Returns:
+ A tuple of:
+ Media length, URL downloaded, the HTTP response code,
+ the media type, the downloaded file name, the number of
+ milliseconds the result is valid for, the etag header.
+ """
+
+ try:
+ logger.debug("Trying to get preview for url '%s'", url)
+ length, headers, uri, code = await self.client.get_file(
+ url,
+ output_stream=output_stream,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url, e)
+
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
+
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+
+ return DownloadResult(
+ length, uri, code, media_type, download_name, expires, etag
+ )
+
+ async def _parse_data_url(
+ self, url: str, output_stream: BinaryIO
+ ) -> DownloadResult:
+ """
+ Parses a data: URL.
+
+ Args:
+ url: The URL to parse.
+ output_stream: The stream to write the content to.
+
+ Returns:
+ A tuple of:
+ Media length, URL downloaded, the HTTP response code,
+ the media type, the downloaded file name, the number of
+ milliseconds the result is valid for, the etag header.
+ """
+
+ try:
+ logger.debug("Trying to parse data url '%s'", url)
+ with urlopen(url) as url_info:
+ # TODO Can this be more efficient.
+ output_stream.write(url_info.read())
+ except Exception as e:
+ logger.warning("Error parsing data: URL %s: %r", url, e)
+
+ raise SynapseError(
+ 500,
+ "Failed to parse data URL: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+
+ return DownloadResult(
+ # Read back the length that has been written.
+ length=output_stream.tell(),
+ uri=url,
+ # If it was parsed, consider this a 200 OK.
+ response_code=200,
+ # urlopen shoves the media-type from the data URL into the content type
+ # header object.
+ media_type=url_info.headers.get_content_type(),
+ # Some features are not supported by data: URLs.
+ download_name=None,
+ expires=ONE_HOUR,
+ etag=None,
+ )
+
+ async def _handle_url(
+ self, url: str, user: UserID, allow_data_urls: bool = False
+ ) -> MediaInfo:
+ """
+ Fetches content from a URL and parses the result to generate a MediaInfo.
+
+ It uses the media storage provider to persist the fetched content and
+ stores the mapping into the database.
+
+ Args:
+ url: The URL to fetch.
+ user: The user who ahs requested this URL.
+ allow_data_urls: True if data URLs should be allowed.
+
+ Returns:
+ A MediaInfo object describing the fetched content.
+ """
+
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -377,61 +519,27 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- try:
- logger.debug("Trying to get preview for url '%s'", url)
- length, headers, uri, code = await self.client.get_file(
- url,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url, e)
-
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
+ if url.startswith("data:"):
+ if not allow_data_urls:
+ raise SynapseError(
+ 500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN
+ )
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
+ download_result = await self._parse_data_url(url, f)
else:
- media_type = "application/octet-stream"
+ download_result = await self._download_url(url, f)
- download_name = get_filename_from_headers(headers)
-
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- expires = ONE_HOUR
- etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ await finish()
try:
time_now_ms = self.clock.time_msec()
await self.store.store_local_media(
media_id=file_id,
- media_type=media_type,
+ media_type=download_result.media_type,
time_now_ms=time_now_ms,
- upload_name=download_name,
- media_length=length,
+ upload_name=download_result.download_name,
+ media_length=download_result.length,
user_id=user,
url_cache=url,
)
@@ -444,16 +552,16 @@ class PreviewUrlResource(DirectServeJsonResource):
raise
return MediaInfo(
- media_type=media_type,
- media_length=length,
- download_name=download_name,
+ media_type=download_result.media_type,
+ media_length=download_result.length,
+ download_name=download_result.download_name,
created_ts_ms=time_now_ms,
filesystem_id=file_id,
filename=fname,
- uri=uri,
- response_code=code,
- expires=expires,
- etag=etag,
+ uri=download_result.uri,
+ response_code=download_result.response_code,
+ expires=download_result.expires,
+ etag=download_result.etag,
)
async def _precache_image_url(
@@ -474,8 +582,8 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- image_info = await self._download_url(
- rebase_url(og["og:image"], media_info.uri), user
+ image_info = await self._handle_url(
+ rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
)
if _is_media(image_info.media_type):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7967011afd..8df80664a2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,7 +57,7 @@ class SQLBaseStore(metaclass=ABCMeta):
pass
def _invalidate_state_caches(
- self, room_id: str, members_changed: Iterable[str]
+ self, room_id: str, members_changed: Collection[str]
) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -66,11 +66,16 @@ class SQLBaseStore(metaclass=ABCMeta):
room_id: Room where state changed
members_changed: The user_ids of members that have changed
"""
+ # If there were any membership changes, purge the appropriate caches.
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
+ if members_changed:
+ self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache(
+ "get_users_in_room_with_profiles", (room_id,)
+ )
- self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
- self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
+ # Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 57cc1d76e0..99802228c9 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -702,6 +702,7 @@ class DatabasePool:
func: Callable[..., R],
*args: Any,
db_autocommit: bool = False,
+ isolation_level: Optional[int] = None,
**kwargs: Any,
) -> R:
"""Starts a transaction on the database and runs a given function
@@ -724,6 +725,7 @@ class DatabasePool:
called multiple times if the transaction is retried, so must
correctly handle that case.
+ isolation_level: Set the server isolation level for this transaction.
args: positional args to pass to `func`
kwargs: named args to pass to `func`
@@ -746,6 +748,7 @@ class DatabasePool:
func,
*args,
db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
**kwargs,
)
@@ -763,6 +766,7 @@ class DatabasePool:
func: Callable[..., R],
*args: Any,
db_autocommit: bool = False,
+ isolation_level: Optional[int] = None,
**kwargs: Any,
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
@@ -775,6 +779,7 @@ class DatabasePool:
db_autocommit: Whether to run the function in "autocommit" mode,
i.e. outside of a transaction. This is useful for transaction
that are only a single query. Currently only affects postgres.
+ isolation_level: Set the server isolation level for this transaction.
kwargs: named args to pass to `func`
Returns:
@@ -834,6 +839,10 @@ class DatabasePool:
try:
if db_autocommit:
self.engine.attempt_to_set_autocommit(conn, True)
+ if isolation_level is not None:
+ self.engine.attempt_to_set_isolation_level(
+ conn, isolation_level
+ )
db_conn = LoggingDatabaseConnection(
conn, self.engine, "runWithConnection"
@@ -842,6 +851,8 @@ class DatabasePool:
finally:
if db_autocommit:
self.engine.attempt_to_set_autocommit(conn, False)
+ if isolation_level:
+ self.engine.attempt_to_set_isolation_level(conn, None)
return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef475e18c7..52146aacc8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -26,6 +26,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -105,6 +106,11 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"AccountDataAndTagsChangeCache", account_max
)
+ self.db_pool.updates.register_background_update_handler(
+ "delete_account_data_for_deactivated_users",
+ self._delete_account_data_for_deactivated_users,
+ )
+
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
@@ -158,9 +164,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cached(num_args=2, max_entries=5000)
+ @cached(num_args=2, max_entries=5000, tree=True)
async def get_global_account_data_by_type_for_user(
- self, data_type: str, user_id: str
+ self, user_id: str, data_type: str
) -> Optional[JsonDict]:
"""
Returns:
@@ -179,7 +185,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
else:
return None
- @cached(num_args=2)
+ @cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
@@ -210,7 +216,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- @cached(num_args=3, max_entries=5000)
+ @cached(num_args=3, max_entries=5000, tree=True)
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
@@ -392,7 +398,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id)
+ (row.user_id, row.data_type)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
@@ -476,7 +482,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
- (account_data_type, user_id)
+ (user_id, account_data_type)
)
return self._account_data_id_gen.get_current_token()
@@ -546,6 +552,123 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ async def purge_account_data_for_user(self, user_id: str) -> None:
+ """
+ Removes ALL the account data for a user.
+ Intended to be used upon user deactivation.
+
+ Also purges the user from the ignored_users cache table
+ and the push_rules cache tables.
+ """
+
+ await self.db_pool.runInteraction(
+ "purge_account_data_for_user_txn",
+ self._purge_account_data_for_user_txn,
+ user_id,
+ )
+
+ def _purge_account_data_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
+ """
+ See `purge_account_data_for_user`.
+ """
+ # Purge from the primary account_data tables.
+ self.db_pool.simple_delete_txn(
+ txn, table="account_data", keyvalues={"user_id": user_id}
+ )
+
+ self.db_pool.simple_delete_txn(
+ txn, table="room_account_data", keyvalues={"user_id": user_id}
+ )
+
+ # Purge from ignored_users where this user is the ignorer.
+ # N.B. We don't purge where this user is the ignoree, because that
+ # interferes with other users' account data.
+ # It's also not this user's data to delete!
+ self.db_pool.simple_delete_txn(
+ txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
+ )
+
+ # Remove the push rules
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_enable", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_stream", keyvalues={"user_id": user_id}
+ )
+
+ # Invalidate caches as appropriate
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room_and_type, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_global_account_data_by_type_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room, (user_id,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_push_rules_enabled_for_user, (user_id,)
+ )
+ # This user might be contained in the ignored_by cache for other users,
+ # so we have to invalidate it all.
+ self._invalidate_all_cache_and_stream(txn, self.ignored_by)
+
+ async def _delete_account_data_for_deactivated_users(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """
+ Retroactively purges account data for users that have already been deactivated.
+ Gets run as a background update caused by a schema delta.
+ """
+
+ last_user: str = progress.get("last_user", "")
+
+ def _delete_account_data_for_deactivated_users_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ sql = """
+ SELECT name FROM users
+ WHERE deactivated = ? and name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (1, last_user, batch_size))
+ users = [row[0] for row in txn]
+
+ for user in users:
+ self._purge_account_data_for_user_txn(txn, user_id=user)
+
+ if users:
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ "delete_account_data_for_deactivated_users",
+ {"last_user": users[-1]},
+ )
+
+ return len(users)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_delete_account_data_for_deactivated_users",
+ _delete_account_data_for_deactivated_users_txn,
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "delete_account_data_for_deactivated_users"
+ )
+
+ return number_deleted
+
class AccountDataStore(AccountDataWorkerStore):
pass
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 92c95a41d7..304814af5d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict],
+ to_device_messages: List[JsonDict],
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
service: The service who the transaction is for.
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
+ to_device_messages: A list of to-device messages to put in the transaction.
Returns:
A new transaction.
@@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
(service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(
- service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+ service=service,
+ id=new_txn_id,
+ events=events,
+ ephemeral=ephemeral,
+ to_device_messages=to_device_messages,
)
return await self.db_pool.runInteraction(
@@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(
- service=service, id=entry["txn_id"], events=events, ephemeral=[]
+ service=service,
+ id=entry["txn_id"],
+ events=events,
+ ephemeral=[],
+ to_device_messages=[],
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -384,14 +394,14 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = await self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids, get_prev_content=True)
return upper_bound, events
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence"):
+ if type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
- async def set_type_stream_id_for_appservice(
+ async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence"):
+ if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
- def set_type_stream_id_for_appservice_txn(txn):
+ def set_appservice_stream_type_pos_txn(txn):
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
)
await self.db_pool.runInteraction(
- "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+ "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 0024348067..c428dd5596 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate_all)
self._send_invalidation_to_replication(txn, cache_func.__name__, None)
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ def _invalidate_state_caches_and_stream(
+ self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
+ ) -> None:
"""Special case invalidation of caches based on current state.
We special case this so that we can batch the cache invalidations into a
@@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
Args:
txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4eca97189b..8801b7b2dd 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
- async def get_new_messages_for_device(
+ async def get_messages_for_user_devices(
+ self,
+ user_ids: Collection[str],
+ from_stream_id: int,
+ to_stream_id: int,
+ ) -> Dict[Tuple[str, str], List[JsonDict]]:
+ """
+ Retrieve to-device messages for a given set of users.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
+ Args:
+ user_ids: The users to retrieve to-device messages for.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+ Returns:
+ A dictionary of (user id, device id) -> list of to-device messages.
+ """
+ # We expect the stream ID returned by _get_device_messages to always
+ # be to_stream_id. So, no need to return it from this function.
+ (
+ user_id_device_id_to_messages,
+ last_processed_stream_id,
+ ) = await self._get_device_messages(
+ user_ids=user_ids,
+ from_stream_id=from_stream_id,
+ to_stream_id=to_stream_id,
+ )
+
+ assert (
+ last_processed_stream_id == to_stream_id
+ ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
+
+ return user_id_device_id_to_messages
+
+ async def get_messages_for_device(
self,
user_id: str,
- device_id: Optional[str],
- last_stream_id: int,
- current_stream_id: int,
+ device_id: str,
+ from_stream_id: int,
+ to_stream_id: int,
limit: int = 100,
- ) -> Tuple[List[dict], int]:
+ ) -> Tuple[List[JsonDict], int]:
"""
+ Retrieve to-device messages for a single user device.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
Args:
- user_id: The recipient user_id.
- device_id: The recipient device_id.
- last_stream_id: The last stream ID checked.
- current_stream_id: The current position of the to device
- message stream.
- limit: The maximum number of messages to retrieve.
+ user_id: The ID of the user to retrieve messages for.
+ device_id: The ID of the device to retrieve to-device messages for.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+ limit: A limit on the number of to-device messages returned.
Returns:
A tuple containing:
- * A list of messages for the device.
- * The max stream token of these messages. There may be more to retrieve
- if the given limit was reached.
+ * A list of to-device messages within the given stream id range intended for
+ the given user / device combo.
+ * The last-processed stream ID. Subsequent calls of this function with the
+ same device should pass this value as 'from_stream_id'.
"""
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_stream_id
+ (
+ user_id_device_id_to_messages,
+ last_processed_stream_id,
+ ) = await self._get_device_messages(
+ user_ids=[user_id],
+ device_id=device_id,
+ from_stream_id=from_stream_id,
+ to_stream_id=to_stream_id,
+ limit=limit,
)
- if not has_changed:
- return [], current_stream_id
- def get_new_messages_for_device_txn(txn):
- sql = (
- "SELECT stream_id, message_json FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
+ if not user_id_device_id_to_messages:
+ # There were no messages!
+ return [], to_stream_id
+
+ # Extract the messages, no need to return the user and device ID again
+ to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+
+ return to_device_messages, last_processed_stream_id
+
+ async def _get_device_messages(
+ self,
+ user_ids: Collection[str],
+ from_stream_id: int,
+ to_stream_id: int,
+ device_id: Optional[str] = None,
+ limit: Optional[int] = None,
+ ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+ """
+ Retrieve pending to-device messages for a collection of user devices.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
+ Note that a stream ID can be shared by multiple copies of the same message with
+ different recipient devices. Stream IDs are only unique in the context of a single
+ user ID / device ID pair. Thus, applying a limit (of messages to return) when working
+ with a sliding window of stream IDs is only possible when querying messages of a
+ single user device.
+
+ Finally, note that device IDs are not unique across users.
+
+ Args:
+ user_ids: The user IDs to filter device messages by.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+ device_id: A device ID to query to-device messages for. If not provided, to-device
+ messages from all device IDs for the given user IDs will be queried. May not be
+ provided if `user_ids` contains more than one entry.
+ limit: The maximum number of to-device messages to return. Can only be used when
+ passing a single user ID / device ID tuple.
+
+ Returns:
+ A tuple containing:
+ * A dict of (user_id, device_id) -> list of to-device messages
+ * The last-processed stream ID. If this is less than `to_stream_id`, then
+ there may be more messages to retrieve. If `limit` is not set, then this
+ is always equal to 'to_stream_id'.
+ """
+ if not user_ids:
+ logger.warning("No users provided upon querying for device IDs")
+ return {}, to_stream_id
+
+ # Prevent a query for one user's device also retrieving another user's device with
+ # the same device ID (device IDs are not unique across users).
+ if len(user_ids) > 1 and device_id is not None:
+ raise AssertionError(
+ "Programming error: 'device_id' cannot be supplied to "
+ "_get_device_messages when >1 user_id has been provided"
)
- txn.execute(
- sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+
+ # A limit can only be applied when querying for a single user ID / device ID tuple.
+ # See the docstring of this function for more details.
+ if limit is not None and device_id is None:
+ raise AssertionError(
+ "Programming error: _get_device_messages was passed 'limit' "
+ "without a specific user_id/device_id"
)
- messages = []
- stream_pos = current_stream_id
+ user_ids_to_query: Set[str] = set()
+ device_ids_to_query: Set[str] = set()
+
+ # Note that a device ID could be an empty str
+ if device_id is not None:
+ # If a device ID was passed, use it to filter results.
+ # Otherwise, device IDs will be derived from the given collection of user IDs.
+ device_ids_to_query.add(device_id)
+
+ # Determine which users have devices with pending messages
+ for user_id in user_ids:
+ if self._device_inbox_stream_cache.has_entity_changed(
+ user_id, from_stream_id
+ ):
+ # This user has new messages sent to them. Query messages for them
+ user_ids_to_query.add(user_id)
+
+ def get_device_messages_txn(txn: LoggingTransaction):
+ # Build a query to select messages from any of the given devices that
+ # are between the given stream id bounds.
+
+ # If a list of device IDs was not provided, retrieve all devices IDs
+ # for the given users. We explicitly do not query hidden devices, as
+ # hidden devices should not receive to-device messages.
+ # Note that this is more efficient than just dropping `device_id` from the query,
+ # since device_inbox has an index on `(user_id, device_id, stream_id)`
+ if not device_ids_to_query:
+ user_device_dicts = self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("device_id",),
+ )
- for row in txn:
- stream_pos = row[0]
- messages.append(db_to_json(row[1]))
+ device_ids_to_query.update(
+ {row["device_id"] for row in user_device_dicts}
+ )
- # If the limit was not reached we know that there's no more data for this
- # user/device pair up to current_stream_id.
- if len(messages) < limit:
- stream_pos = current_stream_id
+ if not device_ids_to_query:
+ # We've ended up with no devices to query.
+ return {}, to_stream_id
- return messages, stream_pos
+ # We include both user IDs and device IDs in this query, as we have an index
+ # (device_inbox_user_stream_id) for them.
+ user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
+ self.database_engine, "user_id", user_ids_to_query
+ )
+ (
+ device_id_many_clause_sql,
+ device_id_many_clause_args,
+ ) = make_in_list_sql_clause(
+ self.database_engine, "device_id", device_ids_to_query
+ )
+
+ sql = f"""
+ SELECT stream_id, user_id, device_id, message_json FROM device_inbox
+ WHERE {user_id_many_clause_sql}
+ AND {device_id_many_clause_sql}
+ AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ sql_args = (
+ *user_id_many_clause_args,
+ *device_id_many_clause_args,
+ from_stream_id,
+ to_stream_id,
+ )
+
+ # If a limit was provided, limit the data retrieved from the database
+ if limit is not None:
+ sql += "LIMIT ?"
+ sql_args += (limit,)
+
+ txn.execute(sql, sql_args)
+
+ # Create and fill a dictionary of (user ID, device ID) -> list of messages
+ # intended for each device.
+ last_processed_stream_pos = to_stream_id
+ recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
+ for row in txn:
+ last_processed_stream_pos = row[0]
+ recipient_user_id = row[1]
+ recipient_device_id = row[2]
+ message_dict = db_to_json(row[3])
+
+ # Store the device details
+ recipient_device_to_messages.setdefault(
+ (recipient_user_id, recipient_device_id), []
+ ).append(message_dict)
+
+ if limit is not None and txn.rowcount == limit:
+ # We ended up bumping up against the message limit. There may be more messages
+ # to retrieve. Return what we have, as well as the last stream position that
+ # was processed.
+ #
+ # The caller is expected to set this as the lower (exclusive) bound
+ # for the next query of this device.
+ return recipient_device_to_messages, last_processed_stream_pos
+
+ # The limit was not reached, thus we know that recipient_device_to_messages
+ # contains all to-device messages for the given device and stream id range.
+ #
+ # We return to_stream_id, which the caller should then provide as the lower
+ # (exclusive) bound on the next query of this device.
+ return recipient_device_to_messages, to_stream_id
return await self.db_pool.runInteraction(
- "get_new_messages_for_device", get_new_messages_for_device_txn
+ "get_device_messages", get_device_messages_txn
)
@trace
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a556f17dac..ca71f073fc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -65,7 +65,7 @@ class _NoChainCoverIndex(Exception):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
-class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1ae1ebe108..b7554154ac 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1389,6 +1389,8 @@ class PersistEventsStore:
"received_ts",
"sender",
"contains_url",
+ "state_key",
+ "rejection_reason",
),
values=(
(
@@ -1405,8 +1407,10 @@ class PersistEventsStore:
self._clock.time_msec(),
event.sender,
"url" in event.content and isinstance(event.content["url"], str),
+ event.get_state_key(),
+ context.rejected or None,
)
- for event, _ in events_and_contexts
+ for event, context in events_and_contexts
),
)
@@ -1456,6 +1460,7 @@ class PersistEventsStore:
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
+ # (events.rejection_reason has already been done)
self._store_rejections_txn(txn, event.event_id, context.rejected)
to_remove.add(event)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..e87a8fb85d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -390,7 +390,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_search",
"events",
"group_rooms",
- "public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e01c94930a..92539f5d41 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _load_rules(rawrules, enabled_map, use_new_defaults=False):
+def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -52,7 +52,7 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist, use_new_defaults))
+ rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule["rule_id"]
@@ -112,10 +112,6 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- self._users_new_default_push_rules = (
- hs.config.server.users_new_default_push_rules
- )
-
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
@@ -145,9 +141,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- use_new_defaults = user_id in self._users_new_default_push_rules
-
- return _load_rules(rows, enabled_map, use_new_defaults)
+ return _load_rules(rows, enabled_map)
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
@@ -206,13 +200,7 @@ class PushRulesWorkerStore(
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
- use_new_defaults = user_id in self._users_new_default_push_rules
-
- results[user_id] = _load_rules(
- rules,
- enabled_map_by_user.get(user_id, {}),
- use_new_defaults,
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..37468a5183 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterable,
- List,
- Optional,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
import attr
from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ latest_event: EventBase
+ count: int
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -60,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
@cached(tree=True)
@@ -585,7 +599,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +630,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
- aggregations: Dict[str, Any] = {}
+ aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+ aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
+ aggregations.references = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)
if edit:
- aggregations[RelationTypes.REPLACE] = edit
+ aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
@@ -644,11 +658,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id
)
if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- "latest_event": latest_thread_event,
- "count": thread_count,
- "current_user_participated": participated,
- }
+ aggregations.thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ current_user_participated=participated,
+ )
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@@ -657,7 +671,7 @@ class RelationsWorkerStore(SQLBaseStore):
self,
events: Iterable[EventBase],
user_id: str,
- ) -> Dict[str, Dict[str, Any]]:
+ ) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
@@ -668,15 +682,12 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
- # If bundled aggregations are disabled, nothing to do.
- if not self._msc1849_enabled:
- return {}
# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result is not None:
+ if event_result:
results[event.event_id] = event_result
return results
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 3201623fe4..0518b8b910 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable, List, Tuple
+from typing import Collection, Dict, List, Tuple
from unpaddedbase64 import encode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.types import Cursor
+from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage.databases.main.events_worker import (
+ EventRedactBehaviour,
+ EventsWorkerStore,
+)
from synapse.util.caches.descriptors import cached, cachedList
-class SignatureWorkerStore(SQLBaseStore):
+class SignatureWorkerStore(EventsWorkerStore):
@cached()
def get_event_reference_hash(self, event_id):
# This is a dummy function to allow get_event_reference_hashes
@@ -32,7 +35,7 @@ class SignatureWorkerStore(SQLBaseStore):
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
async def get_event_reference_hashes(
- self, event_ids: Iterable[str]
+ self, event_ids: Collection[str]
) -> Dict[str, Dict[str, bytes]]:
"""Get all hashes for given events.
@@ -41,18 +44,27 @@ class SignatureWorkerStore(SQLBaseStore):
Returns:
A mapping of event ID to a mapping of algorithm to hash.
+ Returns an empty dict for a given event id if that event is unknown.
"""
+ events = await self.get_events(
+ event_ids,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ allow_rejected=True,
+ )
- def f(txn):
- return {
- event_id: self._get_event_reference_hashes_txn(txn, event_id)
- for event_id in event_ids
- }
+ hashes: Dict[str, Dict[str, bytes]] = {}
+ for event_id in event_ids:
+ event = events.get(event_id)
+ if event is None:
+ hashes[event_id] = {}
+ else:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ hashes[event_id] = {ref_alg: ref_hash_bytes}
- return await self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return hashes
async def add_event_hashes(
- self, event_ids: Iterable[str]
+ self, event_ids: Collection[str]
) -> List[Tuple[str, Dict[str, str]]]:
"""
@@ -70,24 +82,6 @@ class SignatureWorkerStore(SQLBaseStore):
return list(encoded_hashes.items())
- def _get_event_reference_hashes_txn(
- self, txn: Cursor, event_id: str
- ) -> Dict[str, bytes]:
- """Get all the hashes for a given PDU.
- Args:
- txn:
- event_id: Id for the Event.
- Returns:
- A mapping of algorithm -> hash.
- """
- query = (
- "SELECT algorithm, hash"
- " FROM event_reference_hashes"
- " WHERE event_id = ?"
- )
- txn.execute(query, (event_id,))
- return {k: v for k, v in txn}
-
class SignatureStore(SignatureWorkerStore):
"""Persistence for event signatures and hashes"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
stream_ordering: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+ events_before: List[EventBase]
+ events_after: List[EventBase]
+ start: RoomStreamToken
+ end: RoomStreamToken
+
+
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
- ) -> dict:
+ ) -> _EventsAround:
"""Retrieve events and pagination tokens around a given event in a
room.
"""
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
list(results["after"]["event_ids"]), get_prev_content=True
)
- return {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
+ return _EventsAround(
+ events_before=events_before,
+ events_after=events_after,
+ start=results["before"]["token"],
+ end=results["after"]["token"],
+ )
def _get_events_around_txn(
self,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 4b78b4d098..ba79e19f7f 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -561,6 +561,54 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
"get_destinations_paginate_txn", get_destinations_paginate_txn
)
+ async def get_destination_rooms_paginate(
+ self, destination: str, start: int, limit: int, direction: str = "f"
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of destination's rooms.
+ This will return a json list of rooms and the
+ total number of rooms.
+
+ Args:
+ destination: the destination to query
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ direction: sort ascending or descending by room_id
+ Returns:
+ A tuple of a dict of rooms and a count of total rooms.
+ """
+
+ def get_destination_rooms_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT COUNT(*) as total_rooms
+ FROM destination_rooms
+ WHERE destination = ?
+ """
+ txn.execute(sql, [destination])
+ count = cast(Tuple[int], txn.fetchone())[0]
+
+ rooms = self.db_pool.simple_select_list_paginate_txn(
+ txn=txn,
+ table="destination_rooms",
+ orderby="room_id",
+ start=start,
+ limit=limit,
+ retcols=("room_id", "stream_ordering"),
+ order_direction=order,
+ )
+ return rooms, count
+
+ return await self.db_pool.runInteraction(
+ "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn
+ )
+
async def is_destination_known(self, destination: str) -> bool:
"""Check if a destination is known to the server."""
result = await self.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 20cd63c330..143cd98ca2 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,11 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
-from typing import Generic, TypeVar
+from enum import IntEnum
+from typing import Generic, Optional, TypeVar
from synapse.storage.types import Connection
+class IsolationLevel(IntEnum):
+ READ_COMMITTED: int = 1
+ REPEATABLE_READ: int = 2
+ SERIALIZABLE: int = 3
+
+
class IncorrectDatabaseSetup(RuntimeError):
pass
@@ -109,3 +116,13 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
commit/rollback the connections.
"""
...
+
+ @abc.abstractmethod
+ def attempt_to_set_isolation_level(
+ self, conn: Connection, isolation_level: Optional[int]
+ ):
+ """Attempt to set the connections isolation level.
+
+ Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
+ """
+ ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 30f948a0f7..808342fafb 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,8 +13,13 @@
# limitations under the License.
import logging
+from typing import Mapping, Optional
-from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.engines._base import (
+ BaseDatabaseEngine,
+ IncorrectDatabaseSetup,
+ IsolationLevel,
+)
from synapse.storage.types import Connection
logger = logging.getLogger(__name__)
@@ -34,6 +39,15 @@ class PostgresEngine(BaseDatabaseEngine):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
+ self.isolation_level_map: Mapping[int, int] = {
+ IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
+ IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+ IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+ }
+ self.default_isolation_level = (
+ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
+ )
+
@property
def single_threaded(self) -> bool:
return False
@@ -46,8 +60,8 @@ class PostgresEngine(BaseDatabaseEngine):
self._version = db_conn.server_version
# Are we on a supported PostgreSQL version?
- if not allow_outdated_version and self._version < 90600:
- raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")
+ if not allow_outdated_version and self._version < 100000:
+ raise RuntimeError("Synapse requires PostgreSQL 10 or above.")
with db_conn.cursor() as txn:
txn.execute("SHOW SERVER_ENCODING")
@@ -104,9 +118,7 @@ class PostgresEngine(BaseDatabaseEngine):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
- db_conn.set_isolation_level(
- self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
- )
+ db_conn.set_isolation_level(self.default_isolation_level)
# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
@@ -175,3 +187,12 @@ class PostgresEngine(BaseDatabaseEngine):
def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
return conn.set_session(autocommit=autocommit) # type: ignore
+
+ def attempt_to_set_isolation_level(
+ self, conn: Connection, isolation_level: Optional[int]
+ ):
+ if isolation_level is None:
+ isolation_level = self.default_isolation_level
+ else:
+ isolation_level = self.isolation_level_map[isolation_level]
+ return conn.set_isolation_level(isolation_level) # type: ignore
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 70d17d4f2c..6c19e55999 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -15,6 +15,7 @@ import platform
import struct
import threading
import typing
+from typing import Optional
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Connection
@@ -122,6 +123,12 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
# set the connection to autocommit mode.
pass
+ def attempt_to_set_isolation_level(
+ self, conn: Connection, isolation_level: Optional[int]
+ ):
+ # All transactions are SERIALIZABLE by default in sqllite
+ pass
+
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1823e18720..e3153d1a4a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -499,9 +499,12 @@ def _upgrade_existing_database(
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
- logger.info("Running script %s", relative_path)
- module.run_create(cur, database_engine) # type: ignore
- if not is_empty:
+ if hasattr(module, "run_create"):
+ logger.info("Running %s:run_create", relative_path)
+ module.run_create(cur, database_engine) # type: ignore
+
+ if not is_empty and hasattr(module, "run_upgrade"):
+ logger.info("Running %s:run_upgrade", relative_path)
module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 2a3d47185a..7b21c1b96d 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 67 # remember to update the list below when updating
+SCHEMA_VERSION = 68 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -53,11 +53,18 @@ Changes in SCHEMA_VERSION = 66:
Changes in SCHEMA_VERSION = 67:
- state_events.prev_state is no longer written to.
+
+Changes in SCHEMA_VERSION = 68:
+ - event_reference_hashes is no longer read.
+ - `events` has `state_key` and `rejection_reason` columns, which are populated for
+ new events.
"""
SCHEMA_COMPAT_VERSION = (
- 61 # 61: Remove unused tables `user_stats_historical` and `room_stats_historical`
+ # we now have `state_key` columns in both `events` and `state_events`, so
+ # now incompatible with synapses wth SCHEMA_VERSION < 66.
+ 66
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/67/01drop_public_room_list_stream.sql b/synapse/storage/schema/main/delta/67/01drop_public_room_list_stream.sql
new file mode 100644
index 0000000000..1eb8de9907
--- /dev/null
+++ b/synapse/storage/schema/main/delta/67/01drop_public_room_list_stream.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this table is unused as of Synapse 1.41
+DROP TABLE public_room_list_stream;
+
diff --git a/synapse/storage/schema/main/delta/68/01event_columns.sql b/synapse/storage/schema/main/delta/68/01event_columns.sql
new file mode 100644
index 0000000000..7c072f972e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/01event_columns.sql
@@ -0,0 +1,26 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add new colums to the `events` table which will (one day) make the `state_events`
+-- and `rejections` tables redundant.
+
+ALTER TABLE events
+ -- if this event is a state event, its state key
+ ADD COLUMN state_key TEXT DEFAULT NULL;
+
+
+ALTER TABLE events
+ -- if this event was rejected, the reason it was rejected.
+ ADD COLUMN rejection_reason TEXT DEFAULT NULL;
diff --git a/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql
new file mode 100644
index 0000000000..bbf0af5311
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql
@@ -0,0 +1,21 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track what to_device stream id that this application
+-- service has been caught up to.
+
+-- NULL indicates that this appservice has never received any to_device messages. This
+-- can be used, for example, to avoid sending a huge dump of messages at startup.
+ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;
\ No newline at end of file
diff --git a/synapse/storage/schema/main/delta/68/03_delete_account_data_for_deactivated_accounts.sql b/synapse/storage/schema/main/delta/68/03_delete_account_data_for_deactivated_accounts.sql
new file mode 100644
index 0000000000..e124933843
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/03_delete_account_data_for_deactivated_accounts.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We want to retroactively delete account data for users that were already
+-- deactivated.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6803, 'delete_account_data_for_deactivated_users', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index df8b2f1088..913448f0f9 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -74,21 +74,21 @@ class StateFilter:
@staticmethod
def all() -> "StateFilter":
- """Creates a filter that fetches everything.
+ """Returns a filter that fetches everything.
Returns:
- The new state filter.
+ The state filter.
"""
- return StateFilter(types=frozendict(), include_others=True)
+ return _ALL_STATE_FILTER
@staticmethod
def none() -> "StateFilter":
- """Creates a filter that fetches nothing.
+ """Returns a filter that fetches nothing.
Returns:
The new state filter.
"""
- return StateFilter(types=frozendict(), include_others=False)
+ return _NONE_STATE_FILTER
@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -527,6 +527,10 @@ class StateFilter:
)
+_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
+
+
class StateGroupStorage:
"""High level interface to fetching state for event."""
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 377c9a282a..1d6ec22191 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -81,13 +81,14 @@ class DeferredCache(Generic[KT, VT]):
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
- keylen: The length of the tuple used as the cache key. Ignored unless
- `tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
+ prune_unread_entries: If True, cache entries that haven't been read recently
+ will be evicted from the cache in the background. Set to False to
+ opt-out of this behaviour.
"""
cache_type = TreeCache if tree else dict
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 375cd443f1..df4fb156c2 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -254,9 +254,17 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2
Args:
+ orig:
+ max_entries:
num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
+ tree:
+ cache_context:
+ iterable:
+ prune_unread_entries: If True, cache entries that haven't been read recently
+ will be evicted from the cache in the background. Set to False to opt-out
+ of this behaviour.
"""
def __init__(
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 3f11a2f9dd..7548b38548 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -340,6 +340,12 @@ class LruCache(Generic[KT, VT]):
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
+
+ clock:
+
+ prune_unread_entries: If True, cache entries that haven't been read recently
+ will be evicted from the cache in the background. Set to False to
+ opt-out of this behaviour.
"""
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 17532059e9..1b970ce479 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -87,7 +87,7 @@ async def filter_events_for_client(
)
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST, user_id
+ user_id, AccountDataTypes.IGNORED_USER_LIST
)
ignore_list: FrozenSet[str] = frozenset()
|