diff --git a/synapse/__init__.py b/synapse/__init__.py
index 06b179a7e8..48ac38aec6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.47.0rc2"
+__version__ = "1.47.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 4486b3bc7d..f9f9467dc1 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -30,7 +30,8 @@ FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
-MEDIA_PREFIX = "/_matrix/media/r0"
+MEDIA_R0_PREFIX = "/_matrix/media/r0"
+MEDIA_V3_PREFIX = "/_matrix/media/v3"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
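# A minimal, illustrative sketch (not part of the patch) of what the split
# MEDIA_R0_PREFIX / MEDIA_V3_PREFIX constants are for: the same media
# repository resource gets registered under every prefix, as the worker and
# homeserver hunks later in this diff do. MediaRepoStub is a hypothetical
# stand-in for Synapse's real media repository resource.
MEDIA_R0_PREFIX = "/_matrix/media/r0"
MEDIA_V3_PREFIX = "/_matrix/media/v3"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"


class MediaRepoStub:
    """Stand-in for the real media repository resource."""


def build_media_resources(media_repo: MediaRepoStub) -> dict:
    # All three prefixes serve the same resource, so r0, v3 and legacy v1
    # clients all keep working.
    return {
        MEDIA_R0_PREFIX: media_repo,
        MEDIA_V3_PREFIX: media_repo,
        LEGACY_MEDIA_PREFIX: media_repo,
    }


assert len(set(build_media_resources(MediaRepoStub()).values())) == 1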
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 573bb487b2..807ee3d46e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None:
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
- def handle_sighup(*args: Any, **kwargs: Any) -> None:
+ async def handle_sighup(*args: Any, **kwargs: Any) -> None:
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
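# A rough analogue (using asyncio rather than Synapse's Twisted-based helper)
# of why handle_sighup becomes `async def` above: wrap_as_background_process
# wraps a coroutine function and schedules it as a named background task, so
# the handler itself must be awaitable. wrap_as_background_task below is an
# illustrative sketch, not Synapse's implementation.
import asyncio
import functools
from typing import Any, Awaitable, Callable


def wrap_as_background_task(desc: str):
    def decorator(func: Callable[..., Awaitable[None]]):
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> "asyncio.Task[None]":
            # Schedule the coroutine as a named task instead of awaiting inline.
            return asyncio.get_running_loop().create_task(
                func(*args, **kwargs), name=desc
            )

        return wrapper

    return decorator


@wrap_as_background_task("sighup")
async def handle_sighup(*args: Any, **kwargs: Any) -> None:
    print("RELOADING")


async def main() -> None:
    await handle_sighup()  # the wrapper returns a Task; awaiting it runs the reload


asyncio.run(main())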
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 46f0feff70..b4bed5bf40 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -26,7 +26,8 @@ from synapse.api.urls import (
CLIENT_API_PREFIX,
FEDERATION_PREFIX,
LEGACY_MEDIA_PREFIX,
- MEDIA_PREFIX,
+ MEDIA_R0_PREFIX,
+ MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
@@ -112,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import (
)
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.room_batch import RoomBatchStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
@@ -239,6 +241,7 @@ class GenericWorkerSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomWorkerStore,
+ RoomBatchStore,
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
@@ -338,7 +341,8 @@ class GenericWorkerServer(HomeServer):
resources.update(
{
- MEDIA_PREFIX: media_repo,
+ MEDIA_R0_PREFIX: media_repo,
+ MEDIA_V3_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
"/_synapse/admin": admin_resource,
}
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7bb3744f04..52541faab2 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -29,7 +29,8 @@ from synapse import events
from synapse.api.urls import (
FEDERATION_PREFIX,
LEGACY_MEDIA_PREFIX,
- MEDIA_PREFIX,
+ MEDIA_R0_PREFIX,
+ MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX,
STATIC_PREFIX,
WEB_CLIENT_PREFIX,
@@ -193,6 +194,8 @@ class SynapseHomeServer(HomeServer):
{
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
+ "/_matrix/client/v1": client_resource,
+ "/_matrix/client/v3": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
@@ -244,7 +247,11 @@ class SynapseHomeServer(HomeServer):
if self.config.server.enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update(
- {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
+ {
+ MEDIA_R0_PREFIX: media_repo,
+ MEDIA_V3_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ }
)
elif name == "media":
raise ConfigError(
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index d08f6bbd7f..f51b636417 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -231,13 +231,32 @@ class ApplicationServiceApi(SimpleHttpClient):
json_body=body,
args={"access_token": service.hs_token},
)
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "push_bulk to %s succeeded! events=%s",
+ uri,
+ [event.get("event_id") for event in events],
+ )
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
return True
except CodeMessageException as e:
- logger.warning("push_bulk to %s received %s", uri, e.code)
+ logger.warning(
+ "push_bulk to %s received code=%s msg=%s",
+ uri,
+ e.code,
+ e.msg,
+ exc_info=logger.isEnabledFor(logging.DEBUG),
+ )
except Exception as ex:
- logger.warning("push_bulk to %s threw exception %s", uri, ex)
+ logger.warning(
+ "push_bulk to %s threw exception(%s) %s args=%s",
+ uri,
+ type(ex).__name__,
+ ex,
+ ex.args,
+ exc_info=logger.isEnabledFor(logging.DEBUG),
+ )
failed_transactions_counter.labels(service.id).inc()
return False
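# A self-contained sketch of the logging pattern adopted above: the event-id
# list comprehension is guarded by isEnabledFor() so it only runs when DEBUG
# logging is on, and the warning only carries a traceback (exc_info) in that
# case. The URI and events below are example data.
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("appservice_push_example")

uri = "https://appservice.example.com/transactions/1"
events = [{"event_id": "$abc"}, {"event_id": "$def"}]

if logger.isEnabledFor(logging.DEBUG):
    # Only build the (potentially long) event-id list when it will be logged.
    logger.debug(
        "push_bulk to %s succeeded! events=%s",
        uri,
        [event.get("event_id") for event in events],
    )

try:
    raise ConnectionError("connection refused")
except Exception as ex:
    logger.warning(
        "push_bulk to %s threw exception(%s) %s args=%s",
        uri,
        type(ex).__name__,
        ex,
        ex.args,
        exc_info=logger.isEnabledFor(logging.DEBUG),
    )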
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7c4428a138..1265738dc1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,7 +20,18 @@ import os
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, Iterable, List, MutableMapping, Optional, Union
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
import attr
import jinja2
@@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\
"""
-def path_exists(file_path):
+def path_exists(file_path: str) -> bool:
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
@@ -86,7 +97,7 @@ def path_exists(file_path):
the parent dir).
Returns:
- bool: True if the file exists; False if not.
+ True if the file exists; False if not.
"""
try:
os.stat(file_path)
@@ -102,15 +113,15 @@ class Config:
A configuration section, containing configuration keys and values.
Attributes:
- section (str): The section title of this config object, such as
+ section: The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be
defined in subclasses.
"""
- section = None
+ section: str
- def __init__(self, root_config=None):
+ def __init__(self, root_config: "RootConfig" = None):
self.root = root_config
# Get the path to the default Synapse template directory
@@ -119,7 +130,7 @@ class Config:
)
@staticmethod
- def parse_size(value):
+ def parse_size(value: Union[str, int]) -> int:
if isinstance(value, int):
return value
sizes = {"K": 1024, "M": 1024 * 1024}
@@ -162,15 +173,15 @@ class Config:
return int(value) * size
@staticmethod
- def abspath(file_path):
+ def abspath(file_path: str) -> str:
return os.path.abspath(file_path) if file_path else file_path
@classmethod
- def path_exists(cls, file_path):
+ def path_exists(cls, file_path: str) -> bool:
return path_exists(file_path)
@classmethod
- def check_file(cls, file_path, config_name):
+ def check_file(cls, file_path: Optional[str], config_name: str) -> str:
if file_path is None:
raise ConfigError("Missing config for %s." % (config_name,))
try:
@@ -183,7 +194,7 @@ class Config:
return cls.abspath(file_path)
@classmethod
- def ensure_directory(cls, dir_path):
+ def ensure_directory(cls, dir_path: str) -> str:
dir_path = cls.abspath(dir_path)
os.makedirs(dir_path, exist_ok=True)
if not os.path.isdir(dir_path):
@@ -191,7 +202,7 @@ class Config:
return dir_path
@classmethod
- def read_file(cls, file_path, config_name):
+ def read_file(cls, file_path: Any, config_name: str) -> str:
"""Deprecated: call read_file directly"""
return read_file(file_path, (config_name,))
@@ -284,6 +295,9 @@ class Config:
return [env.get_template(filename) for filename in filenames]
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
+
+
class RootConfig:
"""
Holder of an application's configuration.
@@ -308,7 +322,9 @@ class RootConfig:
raise Exception("Failed making %s: %r" % (config_class.section, e))
setattr(self, config_class.section, conf)
- def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+ def invoke_all(
+ self, func_name: str, *args: Any, **kwargs: Any
+ ) -> MutableMapping[str, Any]:
"""
Invoke a function on all instantiated config objects this RootConfig is
configured to use.
@@ -317,6 +333,7 @@ class RootConfig:
func_name: Name of function to invoke
*args
**kwargs
+
Returns:
ordered dictionary of config section name and the result of the
function from it.
@@ -332,7 +349,7 @@ class RootConfig:
return res
@classmethod
- def invoke_all_static(cls, func_name: str, *args, **kwargs):
+ def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None:
"""
Invoke a static function on config objects this RootConfig is
configured to use.
@@ -341,6 +358,7 @@ class RootConfig:
func_name: Name of function to invoke
*args
**kwargs
+
Returns:
ordered dictionary of config section name and the result of the
function from it.
@@ -351,16 +369,16 @@ class RootConfig:
def generate_config(
self,
- config_dir_path,
- data_dir_path,
- server_name,
- generate_secrets=False,
- report_stats=None,
- open_private_ports=False,
- listeners=None,
- tls_certificate_path=None,
- tls_private_key_path=None,
- ):
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ generate_secrets: bool = False,
+ report_stats: Optional[bool] = None,
+ open_private_ports: bool = False,
+ listeners: Optional[List[dict]] = None,
+ tls_certificate_path: Optional[str] = None,
+ tls_private_key_path: Optional[str] = None,
+ ) -> str:
"""
Build a default configuration file
@@ -368,27 +386,27 @@ class RootConfig:
(eg with --generate_config).
Args:
- config_dir_path (str): The path where the config files are kept. Used to
+ config_dir_path: The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
- data_dir_path (str): The path where the data files are kept. Used to create
+ data_dir_path: The path where the data files are kept. Used to create
filenames for things like the database and media store.
- server_name (str): The server name. Used to initialise the server_name
+ server_name: The server name. Used to initialise the server_name
config param, but also used in the names of some of the config files.
- generate_secrets (bool): True if we should generate new secrets for things
+ generate_secrets: True if we should generate new secrets for things
like the macaroon_secret_key. If False, these parameters will be left
unset.
- report_stats (bool|None): Initial setting for the report_stats setting.
+ report_stats: Initial setting for the report_stats setting.
If None, report_stats will be left unset.
- open_private_ports (bool): True to leave private ports (such as the non-TLS
+ open_private_ports: True to leave private ports (such as the non-TLS
HTTP listener) open to the internet.
- listeners (list(dict)|None): A list of descriptions of the listeners
- synapse should start with each of which specifies a port (str), a list of
+ listeners: A list of descriptions of the listeners synapse should
+ start with, each of which specifies a port (int), a list of
resources (list(str)), tls (bool) and type (str). For example:
[{
"port": 8448,
@@ -403,16 +421,12 @@ class RootConfig:
"type": "http",
}],
+ tls_certificate_path: The path to the tls certificate.
- database (str|None): The database type to configure, either `psycog2`
- or `sqlite3`.
-
- tls_certificate_path (str|None): The path to the tls certificate.
-
- tls_private_key_path (str|None): The path to the tls private key.
+ tls_private_key_path: The path to the tls private key.
Returns:
- str: the yaml config file
+ The yaml config file
"""
return CONFIG_FILE_HEADER + "\n\n".join(
@@ -432,12 +446,15 @@ class RootConfig:
)
@classmethod
- def load_config(cls, description, argv):
+ def load_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> TRootConfig:
"""Parse the commandline and config files
Doesn't support config-file-generation: used by the worker apps.
- Returns: Config object.
+ Returns:
+ Config object.
"""
config_parser = argparse.ArgumentParser(description=description)
cls.add_arguments_to_parser(config_parser)
@@ -446,7 +463,7 @@ class RootConfig:
return obj
@classmethod
- def add_arguments_to_parser(cls, config_parser):
+ def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None:
"""Adds all the config flags to an ArgumentParser.
Doesn't support config-file-generation: used by the worker apps.
@@ -454,7 +471,7 @@ class RootConfig:
Used for workers where we want to add extra flags/subcommands.
Args:
- config_parser (ArgumentParser): App description
+ config_parser: App description
"""
config_parser.add_argument(
@@ -477,7 +494,9 @@ class RootConfig:
cls.invoke_all_static("add_arguments", config_parser)
@classmethod
- def load_config_with_parser(cls, parser, argv):
+ def load_config_with_parser(
+ cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+ ) -> Tuple[TRootConfig, argparse.Namespace]:
"""Parse the commandline and config files with the given parser
Doesn't support config-file-generation: used by the worker apps.
@@ -485,13 +504,12 @@ class RootConfig:
Used for workers where we want to add extra flags/subcommands.
Args:
- parser (ArgumentParser)
- argv (list[str])
+ parser
+ argv
Returns:
- tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed
- config object and the parsed argparse.Namespace object from
- `parser.parse_args(..)`
+ Returns the parsed config object and the parsed argparse.Namespace
+ object from `parser.parse_args(..)`
"""
obj = cls()
@@ -520,12 +538,15 @@ class RootConfig:
return obj, config_args
@classmethod
- def load_or_generate_config(cls, description, argv):
+ def load_or_generate_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> Optional[TRootConfig]:
"""Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app.
- Returns: Config object, or None if --generate-config or --generate-keys was set
+ Returns:
+ Config object, or None if --generate-config or --generate-keys was set
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
@@ -680,16 +701,21 @@ class RootConfig:
return obj
- def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
+ def parse_config_dict(
+ self,
+ config_dict: Dict[str, Any],
+ config_dir_path: Optional[str] = None,
+ data_dir_path: Optional[str] = None,
+ ) -> None:
"""Read the information from the config dict into this Config object.
Args:
- config_dict (dict): Configuration data, as read from the yaml
+ config_dict: Configuration data, as read from the yaml
- config_dir_path (str): The path where the config files are kept. Used to
+ config_dir_path: The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
- data_dir_path (str): The path where the data files are kept. Used to create
+ data_dir_path: The path where the data files are kept. Used to create
filenames for things like the database and media store.
"""
self.invoke_all(
@@ -699,17 +725,20 @@ class RootConfig:
data_dir_path=data_dir_path,
)
- def generate_missing_files(self, config_dict, config_dir_path):
+ def generate_missing_files(
+ self, config_dict: Dict[str, Any], config_dir_path: str
+ ) -> None:
self.invoke_all("generate_files", config_dict, config_dir_path)
-def read_config_files(config_files):
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
"""Read the config files into a dict
Args:
- config_files (iterable[str]): A list of the config files to read
+ config_files: A list of the config files to read
- Returns: dict
+ Returns:
+ The configuration dictionary.
"""
specified_config = {}
for config_file in config_files:
@@ -733,17 +762,17 @@ def read_config_files(config_files):
return specified_config
-def find_config_files(search_paths):
+def find_config_files(search_paths: List[str]) -> List[str]:
"""Finds config files using a list of search paths. If a path is a file
then that file path is added to the list. If a search path is a directory
then all the "*.yaml" files in that directory are added to the list in
sorted order.
Args:
- search_paths(list(str)): A list of paths to search.
+ search_paths: A list of paths to search.
Returns:
- list(str): A list of file paths.
+ A list of file paths.
"""
config_files = []
@@ -777,7 +806,7 @@ def find_config_files(search_paths):
return config_files
-@attr.s
+@attr.s(auto_attribs=True)
class ShardedWorkerHandlingConfig:
"""Algorithm for choosing which instance is responsible for handling some
sharded work.
@@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig:
below).
"""
- instances = attr.ib(type=List[str])
+ instances: List[str]
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key."""
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index c1d9069798..1eb5f5a68c 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,4 +1,18 @@
-from typing import Any, Iterable, List, Optional
+import argparse
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import jinja2
from synapse.config import (
account_validity,
@@ -19,6 +33,7 @@ from synapse.config import (
logger,
metrics,
modules,
+ oembed,
oidc,
password_auth_providers,
push,
@@ -27,6 +42,7 @@ from synapse.config import (
registration,
repository,
retention,
+ room,
room_directory,
saml2,
server,
@@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str
-def path_exists(file_path: str): ...
+def path_exists(file_path: str) -> bool: ...
+
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
class RootConfig:
server: server.ServerConfig
@@ -61,6 +79,7 @@ class RootConfig:
logging: logger.LoggingConfig
ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
+ oembed: oembed.OembedConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
@@ -80,6 +99,7 @@ class RootConfig:
authproviders: password_auth_providers.PasswordAuthProviderConfig
push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig
+ room: room.RoomConfig
groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig
consent: consent.ConsentConfig
@@ -87,72 +107,85 @@ class RootConfig:
servernotices: server_notices.ServerNoticesConfig
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
- tracer: tracer.TracerConfig
+ tracing: tracer.TracerConfig
redis: redis.RedisConfig
modules: modules.ModulesConfig
caches: cache.CacheConfig
federation: federation.FederationConfig
retention: retention.RetentionConfig
- config_classes: List = ...
+ config_classes: List[Type["Config"]] = ...
def __init__(self) -> None: ...
- def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
+ def invoke_all(
+ self, func_name: str, *args: Any, **kwargs: Any
+ ) -> MutableMapping[str, Any]: ...
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
- def __getattr__(self, item: str): ...
def parse_config_dict(
self,
- config_dict: Any,
- config_dir_path: Optional[Any] = ...,
- data_dir_path: Optional[Any] = ...,
+ config_dict: Dict[str, Any],
+ config_dir_path: Optional[str] = ...,
+ data_dir_path: Optional[str] = ...,
) -> None: ...
- read_config: Any = ...
def generate_config(
self,
config_dir_path: str,
data_dir_path: str,
server_name: str,
generate_secrets: bool = ...,
- report_stats: Optional[str] = ...,
+ report_stats: Optional[bool] = ...,
open_private_ports: bool = ...,
listeners: Optional[Any] = ...,
- database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
- ): ...
+ ) -> str: ...
@classmethod
- def load_or_generate_config(cls, description: Any, argv: Any): ...
+ def load_or_generate_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> Optional[TRootConfig]: ...
@classmethod
- def load_config(cls, description: Any, argv: Any): ...
+ def load_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> TRootConfig: ...
@classmethod
- def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
+ def add_arguments_to_parser(
+ cls, config_parser: argparse.ArgumentParser
+ ) -> None: ...
@classmethod
- def load_config_with_parser(cls, parser: Any, argv: Any): ...
+ def load_config_with_parser(
+ cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+ ) -> Tuple[TRootConfig, argparse.Namespace]: ...
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
class Config:
root: RootConfig
+ default_template_dir: str
def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
- def __getattr__(self, item: str, from_root: bool = ...): ...
@staticmethod
- def parse_size(value: Any): ...
+ def parse_size(value: Union[str, int]) -> int: ...
@staticmethod
- def parse_duration(value: Any): ...
+ def parse_duration(value: Union[str, int]) -> int: ...
@staticmethod
- def abspath(file_path: Optional[str]): ...
+ def abspath(file_path: Optional[str]) -> str: ...
@classmethod
- def path_exists(cls, file_path: str): ...
+ def path_exists(cls, file_path: str) -> bool: ...
@classmethod
- def check_file(cls, file_path: str, config_name: str): ...
+ def check_file(cls, file_path: str, config_name: str) -> str: ...
@classmethod
- def ensure_directory(cls, dir_path: str): ...
+ def ensure_directory(cls, dir_path: str) -> str: ...
@classmethod
- def read_file(cls, file_path: str, config_name: str): ...
+ def read_file(cls, file_path: str, config_name: str) -> str: ...
+ def read_template(self, filename: str) -> jinja2.Template: ...
+ def read_templates(
+ self,
+ filenames: List[str],
+ custom_template_directories: Optional[Iterable[str]] = None,
+ ) -> List[jinja2.Template]: ...
-def read_config_files(config_files: List[str]): ...
-def find_config_files(search_paths: List[str]): ...
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ...
+def find_config_files(search_paths: List[str]) -> List[str]: ...
class ShardedWorkerHandlingConfig:
instances: List[str]
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index d119427ad8..f054455534 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -15,7 +15,7 @@
import os
import re
import threading
-from typing import Callable, Dict
+from typing import Callable, Dict, Optional
from synapse.python_dependencies import DependencyException, check_requirements
@@ -217,7 +217,7 @@ class CacheConfig(Config):
expiry_time = cache_config.get("expiry_time")
if expiry_time:
- self.expiry_time_msec = self.parse_duration(expiry_time)
+ self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time)
else:
self.expiry_time_msec = None
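# A short sketch of the typing pattern used in the cache and server config
# changes above: when an attribute is an int in one branch and None in the
# other, annotating the first assignment as Optional[int] lets mypy accept
# both. ExampleConfig is illustrative, not Synapse code.
from typing import Optional


class ExampleConfig:
    def __init__(self, raw_expiry: Optional[str]) -> None:
        if raw_expiry:
            # Without the explicit annotation mypy infers `int` here and then
            # rejects the `None` assignment in the else branch.
            self.expiry_time_msec: Optional[int] = int(raw_expiry) * 1000
        else:
            self.expiry_time_msec = None


assert ExampleConfig("30").expiry_time_msec == 30_000
assert ExampleConfig(None).expiry_time_msec is None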
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index afd65fecd3..510b647c63 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -137,33 +137,14 @@ class EmailConfig(Config):
if self.root.registration.account_threepid_delegate_email
else ThreepidBehaviour.LOCAL
)
- # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would
- # use an identity server to password reset tokens on its behalf. We now warn the user
- # if they have this set and tell them to use the updated option, while using a default
- # identity server in the process.
- self.using_identity_server_from_trusted_list = False
- if (
- not self.root.registration.account_threepid_delegate_email
- and config.get("trust_identity_server_for_password_resets", False) is True
- ):
- # Use the first entry in self.trusted_third_party_id_servers instead
- if self.trusted_third_party_id_servers:
- # XXX: It's a little confusing that account_threepid_delegate_email is modified
- # both in RegistrationConfig and here. We should factor this bit out
- first_trusted_identity_server = self.trusted_third_party_id_servers[0]
-
- # trusted_third_party_id_servers does not contain a scheme whereas
- # account_threepid_delegate_email is expected to. Presume https
- self.root.registration.account_threepid_delegate_email = (
- "https://" + first_trusted_identity_server
- )
- self.using_identity_server_from_trusted_list = True
- else:
- raise ConfigError(
- "Attempted to use an identity server from"
- '"trusted_third_party_id_servers" but it is empty.'
- )
+ if config.get("trust_identity_server_for_password_resets"):
+ raise ConfigError(
+ 'The config option "trust_identity_server_for_password_resets" '
+ 'has been replaced by "account_threepid_delegate". '
+ "Please consult the sample config at docs/sample_config.yaml for "
+ "details and update your config file."
+ )
self.local_threepid_handling_disabled_due_to_email_config = False
if (
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 9d295f5856..24c3ef01fc 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -31,6 +31,8 @@ class JWTConfig(Config):
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
+ self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
+
# The issuer and audiences are optional, if provided, it is asserted
# that the claims exist on the JWT.
self.jwt_issuer = jwt_config.get("issuer")
@@ -46,6 +48,7 @@ class JWTConfig(Config):
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
+ self.jwt_subject_claim = None
self.jwt_issuer = None
self.jwt_audiences = None
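# A rough sketch (assuming the PyJWT library, which Synapse used for JWT login
# at this point) of what the new `subject_claim` option enables: reading the
# user identifier from a configurable claim rather than always using "sub".
# The secret and claim values are example data.
import jwt  # PyJWT

jwt_secret = "not-a-real-secret"
jwt_algorithm = "HS256"
jwt_subject_claim = "preferred_username"  # set via `subject_claim` in the config

token = jwt.encode(
    {"preferred_username": "alice", "sub": "something-else"},
    jwt_secret,
    algorithm=jwt_algorithm,
)
claims = jwt.decode(token, jwt_secret, algorithms=[jwt_algorithm])

# With the default config this would be claims.get("sub").
localpart = claims.get(jwt_subject_claim)
assert localpart == "alice"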
@@ -88,6 +91,12 @@ class JWTConfig(Config):
#
#algorithm: "provided-by-your-issuer"
+ # Name of the claim containing a unique identifier for the user.
+ #
+ # Optional, defaults to `sub`.
+ #
+ #subject_claim: "sub"
+
# The issuer to validate the "iss" claim against.
#
# Optional, if provided the "iss" claim will be required and
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 015dbb8a67..035ee2416b 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,6 +16,7 @@
import hashlib
import logging
import os
+from typing import Any, Dict
import attr
import jsonschema
@@ -312,7 +313,7 @@ class KeyConfig(Config):
)
return keys
- def generate_files(self, config, config_dir_path):
+ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
if "signing_key" in config:
return
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 5252e61a99..63aab0babe 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -18,7 +18,7 @@ import os
import sys
import threading
from string import Template
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict
import yaml
from zope.interface import implementer
@@ -185,7 +185,7 @@ class LoggingConfig(Config):
help=argparse.SUPPRESS,
)
- def generate_files(self, config, config_dir_path):
+ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 5379e80715..1ddad7cb70 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
@@ -39,9 +40,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
- self.trusted_third_party_id_servers = config.get(
- "trusted_third_party_id_servers", ["matrix.org", "vector.im"]
- )
+
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
@@ -114,26 +113,25 @@ class RegistrationConfig(Config):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
- # The `access_token_lifetime` applies for tokens that can be renewed
- # using a refresh token, as per MSC2918. If it is `None`, the refresh
- # token mechanism is disabled.
- #
- # Since it is incompatible with the `session_lifetime` mechanism, it is set to
- # `None` by default if a `session_lifetime` is set.
- access_token_lifetime = config.get(
- "access_token_lifetime", "5m" if session_lifetime is None else None
+ # The `refreshable_access_token_lifetime` applies for tokens that can be renewed
+ # using a refresh token, as per MSC2918.
+ # If it is `None`, the refresh token mechanism is disabled.
+ refreshable_access_token_lifetime = config.get(
+ "refreshable_access_token_lifetime",
+ "5m",
)
- if access_token_lifetime is not None:
- access_token_lifetime = self.parse_duration(access_token_lifetime)
- self.access_token_lifetime = access_token_lifetime
-
- if session_lifetime is not None and access_token_lifetime is not None:
- raise ConfigError(
- "The refresh token mechanism is incompatible with the "
- "`session_lifetime` option. Consider disabling the "
- "`session_lifetime` option or disabling the refresh token "
- "mechanism by removing the `access_token_lifetime` option."
+ if refreshable_access_token_lifetime is not None:
+ refreshable_access_token_lifetime = self.parse_duration(
+ refreshable_access_token_lifetime
)
+ self.refreshable_access_token_lifetime: Optional[
+ int
+ ] = refreshable_access_token_lifetime
+
+ refresh_token_lifetime = config.get("refresh_token_lifetime")
+ if refresh_token_lifetime is not None:
+ refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
+ self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime
# The fallback template used for authenticating using a registration token
self.registration_token_template = self.read_template("registration_token.html")
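# A small worked example of the duration handling behind the new
# `refreshable_access_token_lifetime` default of "5m". Config.parse_duration
# returns milliseconds; this standalone helper mirrors a subset of its units.
def parse_duration_ms(value: str) -> int:
    units = {"s": 1000, "m": 60 * 1000, "h": 60 * 60 * 1000, "d": 24 * 60 * 60 * 1000}
    if value and value[-1] in units:
        return int(value[:-1]) * units[value[-1]]
    return int(value)  # bare numbers are taken as milliseconds


# The default refreshable access token lifetime of "5m" is 300,000 ms.
assert parse_duration_ms("5m") == 5 * 60 * 1000 == 300_000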
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 56981cac79..57316c59b6 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -1,4 +1,5 @@
# Copyright 2018 New Vector Ltd
+# Copyright 2021 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
+
+from synapse.types import JsonDict
from synapse.util import glob_to_regex
from ._base import Config, ConfigError
@@ -20,7 +24,7 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
section = "roomdirectory"
- def read_config(self, config, **kwargs):
+ def read_config(self, config, **kwargs) -> None:
self.enable_room_list_search = config.get("enable_room_list_search", True)
alias_creation_rules = config.get("alias_creation_rules")
@@ -47,7 +51,7 @@ class RoomDirectoryConfig(Config):
_RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
]
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """
# Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote
@@ -113,16 +117,16 @@ class RoomDirectoryConfig(Config):
# action: allow
"""
- def is_alias_creation_allowed(self, user_id, room_id, alias):
+ def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool:
"""Checks if the given user is allowed to create the given alias
Args:
- user_id (str)
- room_id (str)
- alias (str)
+ user_id: The user to check.
+ room_id: The room ID for the alias.
+ alias: The alias being created.
Returns:
- boolean: True if user is allowed to create the alias
+ True if user is allowed to create the alias
"""
for rule in self._alias_creation_rules:
if rule.matches(user_id, room_id, [alias]):
@@ -130,16 +134,18 @@ class RoomDirectoryConfig(Config):
return False
- def is_publishing_room_allowed(self, user_id, room_id, aliases):
+ def is_publishing_room_allowed(
+ self, user_id: str, room_id: str, aliases: List[str]
+ ) -> bool:
"""Checks if the given user is allowed to publish the room
Args:
- user_id (str)
- room_id (str)
- aliases (list[str]): any local aliases associated with the room
+ user_id: The user ID publishing the room.
+ room_id: The room being published.
+ aliases: any local aliases associated with the room
Returns:
- boolean: True if user can publish room
+ True if user can publish room
"""
for rule in self._room_list_publication_rules:
if rule.matches(user_id, room_id, aliases):
@@ -153,11 +159,11 @@ class _RoomDirectoryRule:
creating an alias or publishing a room.
"""
- def __init__(self, option_name, rule):
+ def __init__(self, option_name: str, rule: JsonDict):
"""
Args:
- option_name (str): Name of the config option this rule belongs to
- rule (dict): The rule as specified in the config
+ option_name: Name of the config option this rule belongs to
+ rule: The rule as specified in the config
"""
action = rule["action"]
@@ -181,18 +187,18 @@ class _RoomDirectoryRule:
except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e
- def matches(self, user_id, room_id, aliases):
+ def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases.
Args:
- user_id (str)
- room_id (str)
- aliases (list[str]): The associated aliases to the room. Will be a
- single element for testing alias creation, and can be empty for
- testing room publishing.
+ user_id: The user ID to check.
+ room_id: The room ID to check.
+ aliases: The associated aliases to the room. Will be a single element
+ for testing alias creation, and can be empty for testing room
+ publishing.
Returns:
- boolean
+ True if the rule matches.
"""
# Note: The regexes are anchored at both ends
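# A minimal sketch of how an alias-creation rule like the ones parsed above is
# evaluated. Synapse compiles its globs with its own glob_to_regex helper and
# anchors the resulting regexes; this illustration uses fnmatch instead, which
# behaves similarly for * and ?. The rule values are example data.
from fnmatch import fnmatchcase
from typing import List


class AliasRule:
    def __init__(self, user_id_glob: str, alias_glob: str, action: str) -> None:
        self.user_id_glob = user_id_glob
        self.alias_glob = alias_glob
        self.allow = action == "allow"

    def matches(self, user_id: str, aliases: List[str]) -> bool:
        return fnmatchcase(user_id, self.user_id_glob) and any(
            fnmatchcase(alias, self.alias_glob) for alias in aliases
        )


rule = AliasRule("@admin:*", "#*:example.com", action="allow")
assert rule.matches("@admin:example.com", ["#lounge:example.com"]) and rule.allow
assert not rule.matches("@alice:example.com", ["#lounge:example.com"])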
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7bc0030a9e..8445e9dd05 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -421,7 +421,7 @@ class ServerConfig(Config):
# before redacting them.
redaction_retention_period = config.get("redaction_retention_period", "7d")
if redaction_retention_period is not None:
- self.redaction_retention_period = self.parse_duration(
+ self.redaction_retention_period: Optional[int] = self.parse_duration(
redaction_retention_period
)
else:
@@ -430,7 +430,7 @@ class ServerConfig(Config):
# How long to keep entries in the `users_ips` table.
user_ips_max_age = config.get("user_ips_max_age", "28d")
if user_ips_max_age is not None:
- self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+ self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
else:
self.user_ips_max_age = None
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 6227434bac..4ca111618f 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -14,7 +14,6 @@
import logging
import os
-from datetime import datetime
from typing import List, Optional, Pattern
from OpenSSL import SSL, crypto
@@ -133,55 +132,6 @@ class TlsConfig(Config):
self.tls_certificate: Optional[crypto.X509] = None
self.tls_private_key: Optional[crypto.PKey] = None
- def is_disk_cert_valid(self, allow_self_signed=True):
- """
- Is the certificate we have on disk valid, and if so, for how long?
-
- Args:
- allow_self_signed (bool): Should we allow the certificate we
- read to be self signed?
-
- Returns:
- int: Days remaining of certificate validity.
- None: No certificate exists.
- """
- if not os.path.exists(self.tls_certificate_file):
- return None
-
- try:
- with open(self.tls_certificate_file, "rb") as f:
- cert_pem = f.read()
- except Exception as e:
- raise ConfigError(
- "Failed to read existing certificate file %s: %s"
- % (self.tls_certificate_file, e)
- )
-
- try:
- tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
- except Exception as e:
- raise ConfigError(
- "Failed to parse existing certificate file %s: %s"
- % (self.tls_certificate_file, e)
- )
-
- if not allow_self_signed:
- if tls_certificate.get_subject() == tls_certificate.get_issuer():
- raise ValueError(
- "TLS Certificate is self signed, and this is not permitted"
- )
-
- # YYYYMMDDhhmmssZ -- in UTC
- expiry_data = tls_certificate.get_notAfter()
- if expiry_data is None:
- raise ValueError(
- "TLS Certificate has no expiry date, and this is not permitted"
- )
- expires_on = datetime.strptime(expiry_data.decode("ascii"), "%Y%m%d%H%M%SZ")
- now = datetime.utcnow()
- days_remaining = (expires_on - now).days
- return days_remaining
-
def read_certificate_from_disk(self):
"""
Read the certificates and private key from disk.
@@ -263,8 +213,8 @@ class TlsConfig(Config):
#
#federation_certificate_verification_whitelist:
# - lon.example.com
- # - *.domain.com
- # - *.onion
+ # - "*.domain.com"
+ # - "*.onion"
# List of custom certificate authorities for federation traffic.
#
@@ -295,7 +245,7 @@ class TlsConfig(Config):
cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path)
cert_pem = self.read_file(cert_path, "tls_certificate_path")
- cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+ cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode())
return cert
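# A small sketch of the bytes-vs-str fix above: Config.read_file() returns a
# str, while pyOpenSSL's crypto.load_certificate() takes the PEM data as bytes,
# hence the added .encode(). The self-signed certificate below is generated on
# the fly purely for illustration.
from OpenSSL import crypto

key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048)

cert = crypto.X509()
cert.get_subject().CN = "example.com"
cert.set_issuer(cert.get_subject())
cert.set_serial_number(1)
cert.set_pubkey(key)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(60 * 60)
cert.sign(key, "sha256")

# What read_file() hands back: a str containing the PEM.
cert_pem: str = crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("ascii")

# Encode to bytes before handing it to load_certificate, as the patch does.
reloaded = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode())
assert reloaded.get_subject().CN == "example.com"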
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 2552f688d0..6d6678c7e4 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -53,8 +53,8 @@ class UserDirectoryConfig(Config):
# indexes were (re)built was before Synapse 1.44, you'll have to
# rebuild the indexes in order to search through all known users.
# These indexes are built the first time Synapse starts; admins can
- # manually trigger a rebuild following the instructions at
- # https://matrix-org.github.io/synapse/latest/user_directory.html
+ # manually trigger a rebuild via API following the instructions at
+ # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f641ab7ef5..993b04099e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -120,16 +119,6 @@ class VerifyJsonRequest:
key_ids=key_ids,
)
- def to_fetch_key_request(self) -> "_FetchKeyRequest":
- """Create a key fetch request for all keys needed to satisfy the
- verification request.
- """
- return _FetchKeyRequest(
- server_name=self.server_name,
- minimum_valid_until_ts=self.minimum_valid_until_ts,
- key_ids=self.key_ids,
- )
-
class KeyLookupError(ValueError):
pass
@@ -179,8 +168,22 @@ class Keyring:
clock=hs.get_clock(),
process_batch_callback=self._inner_fetch_key_requests,
)
- self.verify_key = get_verify_key(hs.signing_key)
- self.hostname = hs.hostname
+
+ self._hostname = hs.hostname
+
+ # build a FetchKeyResult for each of our own keys, to shortcircuit the
+ # fetcher.
+ self._local_verify_keys: Dict[str, FetchKeyResult] = {}
+ for key_id, key in hs.config.key.old_signing_keys.items():
+ self._local_verify_keys[key_id] = FetchKeyResult(
+ verify_key=key, valid_until_ts=key.expired_ts
+ )
+
+ vk = get_verify_key(hs.signing_key)
+ self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
+ verify_key=vk,
+ valid_until_ts=2 ** 63, # fake future timestamp
+ )
async def verify_json_for_server(
self,
@@ -267,22 +270,32 @@ class Keyring:
Codes.UNAUTHORIZED,
)
- # If we are the originating server don't fetch verify key for self over federation
- if verify_request.server_name == self.hostname:
- await self._process_json(self.verify_key, verify_request)
- return
+ found_keys: Dict[str, FetchKeyResult] = {}
- # Add the keys we need to verify to the queue for retrieval. We queue
- # up requests for the same server so we don't end up with many in flight
- # requests for the same keys.
- key_request = verify_request.to_fetch_key_request()
- found_keys_by_server = await self._server_queue.add_to_queue(
- key_request, key=verify_request.server_name
- )
+ # If we are the originating server, short-circuit the key-fetch for any keys
+ # we already have
+ if verify_request.server_name == self._hostname:
+ for key_id in verify_request.key_ids:
+ if key_id in self._local_verify_keys:
+ found_keys[key_id] = self._local_verify_keys[key_id]
+
+ key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
+ if key_ids_to_find:
+ # Add the keys we need to verify to the queue for retrieval. We queue
+ # up requests for the same server so we don't end up with many in flight
+ # requests for the same keys.
+ key_request = _FetchKeyRequest(
+ server_name=verify_request.server_name,
+ minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
+ key_ids=list(key_ids_to_find),
+ )
+ found_keys_by_server = await self._server_queue.add_to_queue(
+ key_request, key=verify_request.server_name
+ )
- # Since we batch up requests the returned set of keys may contain keys
- # from other servers, so we pull out only the ones we care about.s
- found_keys = found_keys_by_server.get(verify_request.server_name, {})
+ # Since we batch up requests the returned set of keys may contain keys
+ # from other servers, so we pull out only the ones we care about.
+ found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
# Verify each signature we got valid keys for, raising if we can't
# verify any of them.
@@ -654,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name,
)
+ request: JsonDict = {}
+ for queue_value in keys_to_fetch:
+ # there may be multiple requests for each server, so we have to merge
+ # them intelligently.
+ request_for_server = {
+ key_id: {
+ "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+ }
+ for key_id in queue_value.key_ids
+ }
+ request.setdefault(queue_value.server_name, {}).update(request_for_server)
+
+ logger.debug("Request to notary server %s: %s", perspective_name, request)
+
try:
query_response = await self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
- data={
- "server_keys": {
- queue_value.server_name: {
- key_id: {
- "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
- }
- for key_id in queue_value.key_ids
- }
- for queue_value in keys_to_fetch
- }
- },
+ data={"server_keys": request},
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon
@@ -676,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
+ logger.debug(
+ "Response from notary server %s: %s", perspective_name, query_response
+ )
+
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
added_keys: List[Tuple[str, str, FetchKeyResult]] = []
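# A condensed sketch of the short-circuit added to Keyring above: build a map
# of our own verify keys once, then only queue a fetch for the key ids we
# cannot satisfy locally. Uses the signedjson library, as Synapse does; the
# key version and the extra requested key id are example data (and the real
# code wraps each key in a FetchKeyResult).
from signedjson.key import generate_signing_key, get_verify_key

signing_key = generate_signing_key("a_version")

local_verify_keys = {}
vk = get_verify_key(signing_key)
local_verify_keys[f"{vk.alg}:{vk.version}"] = vk

requested_key_ids = {f"{vk.alg}:{vk.version}", "ed25519:an_old_key"}

found_keys = {
    key_id: local_verify_keys[key_id]
    for key_id in requested_key_ids
    if key_id in local_verify_keys
}
key_ids_to_find = requested_key_ids - found_keys.keys()

# Only the key id we don't hold locally still needs to be fetched remotely.
assert key_ids_to_find == {"ed25519:an_old_key"}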
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index d7527008c4..f251402ed8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -322,6 +322,11 @@ class _AsyncEventContextImpl(EventContext):
attributes by loading from the database.
"""
if self.state_group is None:
+ # No state group means the event is an outlier. Usually the state_ids dicts are also
+ # pre-set to empty dicts, but they get reset when the context is serialized, so set
+ # them to empty dicts again here.
+ self._current_state_ids = {}
+ self._prev_state_ids = {}
return
current_state_ids = await self._storage.state.get_state_ids_for_group(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 6fa631aa1d..e5967c995e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -392,15 +393,16 @@ class EventClientSerializer:
self,
event: Union[JsonDict, EventBase],
time_now: int,
- bundle_aggregations: bool = True,
+ bundle_relations: bool = True,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Args:
- event
+ event: The event being serialized.
time_now: The current time in milliseconds
- bundle_aggregations: Whether to bundle in related events
+ bundle_relations: Whether to include the bundled relations for this
+ event.
**kwargs: Arguments to pass to `serialize_event`
Returns:
@@ -410,77 +412,93 @@ class EventClientSerializer:
if not isinstance(event, EventBase):
return event
- event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
# If MSC1849 is enabled then we need to look if there are any relations
# we need to bundle in with the event.
# Do not bundle relations if the event has been redacted
if not event.internal_metadata.is_redacted() and (
- self._msc1849_enabled and bundle_aggregations
+ self._msc1849_enabled and bundle_relations
):
- annotations = await self.store.get_aggregation_groups_for_event(event_id)
- references = await self.store.get_relations_for_event(
- event_id, RelationTypes.REFERENCE, direction="f"
- )
-
- if annotations.chunk:
- r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.ANNOTATION] = annotations.to_dict()
-
- if references.chunk:
- r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.REFERENCE] = references.to_dict()
-
- edit = None
- if event.type == EventTypes.Message:
- edit = await self.store.get_applicable_edit(event_id)
-
- if edit:
- # If there is an edit replace the content, preserving existing
- # relations.
-
- # Ensure we take copies of the edit content, otherwise we risk modifying
- # the original event.
- edit_content = edit.content.copy()
-
- # Unfreeze the event content if necessary, so that we may modify it below
- edit_content = unfreeze(edit_content)
- serialized_event["content"] = edit_content.get("m.new_content", {})
-
- # Check for existing relations
- relations = event.content.get("m.relates_to")
- if relations:
- # Keep the relations, ensuring we use a dict copy of the original
- serialized_event["content"]["m.relates_to"] = relations.copy()
- else:
- serialized_event["content"].pop("m.relates_to", None)
-
- r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.REPLACE] = {
- "event_id": edit.event_id,
- "origin_server_ts": edit.origin_server_ts,
- "sender": edit.sender,
- }
-
- # If this event is the start of a thread, include a summary of the replies.
- if self._msc3440_enabled:
- (
- thread_count,
- latest_thread_event,
- ) = await self.store.get_thread_summary(event_id)
- if latest_thread_event:
- r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.THREAD] = {
- # Don't bundle aggregations as this could recurse forever.
- "latest_event": await self.serialize_event(
- latest_thread_event, time_now, bundle_aggregations=False
- ),
- "count": thread_count,
- }
+ await self._injected_bundled_relations(event, time_now, serialized_event)
return serialized_event
+ async def _injected_bundled_relations(
+ self, event: EventBase, time_now: int, serialized_event: JsonDict
+ ) -> None:
+ """Potentially injects bundled relations into the unsigned portion of the serialized event.
+
+ Args:
+ event: The event being serialized.
+ time_now: The current time in milliseconds
+ serialized_event: The serialized event which may be modified.
+
+ """
+ event_id = event.event_id
+
+ # The bundled relations to include.
+ relations = {}
+
+ annotations = await self.store.get_aggregation_groups_for_event(event_id)
+ if annotations.chunk:
+ relations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ references = await self.store.get_relations_for_event(
+ event_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ relations[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = await self.store.get_applicable_edit(event_id)
+
+ if edit:
+ # If there is an edit replace the content, preserving existing
+ # relations.
+
+ # Ensure we take copies of the edit content, otherwise we risk modifying
+ # the original event.
+ edit_content = edit.content.copy()
+
+ # Unfreeze the event content if necessary, so that we may modify it below
+ edit_content = unfreeze(edit_content)
+ serialized_event["content"] = edit_content.get("m.new_content", {})
+
+ # Check for existing relations
+ relates_to = event.content.get("m.relates_to")
+ if relates_to:
+ # Keep the relations, ensuring we use a dict copy of the original
+ serialized_event["content"]["m.relates_to"] = relates_to.copy()
+ else:
+ serialized_event["content"].pop("m.relates_to", None)
+
+ relations[RelationTypes.REPLACE] = {
+ "event_id": edit.event_id,
+ "origin_server_ts": edit.origin_server_ts,
+ "sender": edit.sender,
+ }
+
+ # If this event is the start of a thread, include a summary of the replies.
+ if self._msc3440_enabled:
+ (
+ thread_count,
+ latest_thread_event,
+ ) = await self.store.get_thread_summary(event_id)
+ if latest_thread_event:
+ relations[RelationTypes.THREAD] = {
+ # Don't bundle relations as this could recurse forever.
+ "latest_event": await self.serialize_event(
+ latest_thread_event, time_now, bundle_relations=False
+ ),
+ "count": thread_count,
+ }
+
+ # If any bundled relations were found, include them.
+ if relations:
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(relations)
+
async def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
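# A tiny sketch of the merge step at the end of _injected_bundled_relations
# above: the collected relations are folded into unsigned["m.relations"]
# without clobbering anything already present. The event dict is example data.
serialized_event = {
    "type": "m.room.message",
    "unsigned": {"m.relations": {"m.annotation": {"chunk": []}}},
}

relations = {
    "m.replace": {
        "event_id": "$edit_event",
        "origin_server_ts": 1636000000000,
        "sender": "@alice:example.com",
    }
}

if relations:
    serialized_event["unsigned"].setdefault("m.relations", {}).update(relations)

# The pre-existing annotation entry survives alongside the new replace entry.
assert set(serialized_event["unsigned"]["m.relations"]) == {"m.annotation", "m.replace"}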
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3b85b135e0..bc3f96c1fc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1395,11 +1395,28 @@ class FederationClient(FederationBase):
async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
- res = await self.transport_layer.get_room_hierarchy(
- destination=destination,
- room_id=room_id,
- suggested_only=suggested_only,
- )
+ try:
+ res = await self.transport_layer.get_room_hierarchy(
+ destination=destination,
+ room_id=room_id,
+ suggested_only=suggested_only,
+ )
+ except HttpResponseException as e:
+ # If an error is received that is due to an unrecognised endpoint,
+ # fall back to the unstable endpoint. Otherwise consider it a
+ # legitimate error and raise.
+ if not self._is_unknown_endpoint(e):
+ raise
+
+ logger.debug(
+ "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API"
+ )
+
+ res = await self.transport_layer.get_room_hierarchy_unstable(
+ destination=destination,
+ room_id=room_id,
+ suggested_only=suggested_only,
+ )
room = res.get("room")
if not isinstance(room, dict):
@@ -1449,6 +1466,10 @@ class FederationClient(FederationBase):
if e.code != 502:
raise
+ logger.debug(
+ "Couldn't fetch room hierarchy, falling back to the spaces API"
+ )
+
# Fallback to the old federation API and translate the results if
# no servers implement the new API.
#
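# A stripped-down sketch of the endpoint fallback added to get_room_hierarchy
# above: try the stable /v1/hierarchy endpoint first and, only if the remote
# server does not recognise it, retry the unstable-prefixed one. fetch() and
# UnknownEndpoint are illustrative stand-ins for the transport layer and for
# an HttpResponseException classified by _is_unknown_endpoint().
import asyncio


class UnknownEndpoint(Exception):
    """Stand-in for an error caused by an unrecognised endpoint."""


async def fetch(path: str) -> dict:
    # Pretend the remote server only implements the unstable endpoint.
    if path.startswith("/_matrix/federation/v1/hierarchy/"):
        raise UnknownEndpoint(path)
    return {"room": {"room_id": "!space:example.com"}, "children": []}


async def get_room_hierarchy(room_id: str) -> dict:
    try:
        return await fetch(f"/_matrix/federation/v1/hierarchy/{room_id}")
    except UnknownEndpoint:
        # Fall back to the unstable prefix, mirroring the real client code.
        return await fetch(
            f"/_matrix/federation/unstable/org.matrix.msc2946/hierarchy/{room_id}"
        )


print(asyncio.run(get_room_hierarchy("!space:example.com")))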
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9a8758e9a6..8fbc75aa65 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -613,8 +613,11 @@ class FederationServer(FederationBase):
state = await self.store.get_events(state_ids)
time_now = self._clock.time_msec()
+ event_json = event.get_pdu_json()
return {
- "org.matrix.msc3083.v2.event": event.get_pdu_json(),
+ # TODO Remove the unstable prefix when servers have updated.
+ "org.matrix.msc3083.v2.event": event_json,
+ "event": event_json,
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 10b5aa5af8..fe29bcfd4b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1192,10 +1192,24 @@ class TransportLayerClient:
)
async def get_room_hierarchy(
- self,
- destination: str,
- room_id: str,
- suggested_only: bool,
+ self, destination: str, room_id: str, suggested_only: bool
+ ) -> JsonDict:
+ """
+ Args:
+ destination: The remote server
+ room_id: The room ID to ask about.
+ suggested_only: if True, only suggested rooms will be returned
+ """
+ path = _create_v1_path("/hierarchy/%s", room_id)
+
+ return await self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"suggested_only": "true" if suggested_only else "false"},
+ )
+
+ async def get_room_hierarchy_unstable(
+ self, destination: str, room_id: str, suggested_only: bool
) -> JsonDict:
"""
Args:
@@ -1317,15 +1331,26 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
prefix + "auth_chain.item",
use_float=True,
)
- self._coro_event = ijson.kvitems_coro(
+ # TODO Remove the unstable prefix when servers have updated.
+ #
+ # By re-using the same event dictionary this will cause the parsing of
+ # org.matrix.msc3083.v2.event and event to stomp over each other.
+ # Generally this should be fine.
+ self._coro_unstable_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
use_float=True,
)
+ self._coro_event = ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "event",
+ use_float=True,
+ )
def write(self, data: bytes) -> int:
self._coro_state.send(data)
self._coro_auth.send(data)
+ self._coro_unstable_event.send(data)
self._coro_event.send(data)
return len(data)
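# A compact sketch of the dual-prefix parsing added to SendJoinParser above:
# two ijson kvitems coroutines write into the same dict, one fed from the
# unstable `org.matrix.msc3083.v2.event` key and one from the stable `event`
# key, so either form of the response populates the event. The JSON payload is
# example data and dict_filler stands in for Synapse's _event_parser.
import ijson


def dict_filler(target: dict):
    """Return a primed generator that stores each (key, value) pair sent to it."""

    def gen():
        while True:
            key, value = yield
            target[key] = value

    g = gen()
    next(g)  # prime the generator so it can receive .send()
    return g


event_dict: dict = {}
coro_unstable = ijson.kvitems_coro(
    dict_filler(event_dict), "org.matrix.msc3083.v2.event", use_float=True
)
coro_stable = ijson.kvitems_coro(dict_filler(event_dict), "event", use_float=True)

response = b'{"event": {"type": "m.room.member", "depth": 12}}'
for coro in (coro_unstable, coro_stable):
    coro.send(response)
    coro.close()

assert event_dict["type"] == "m.room.member"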
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 2fdf6cc99e..66e915228c 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -611,7 +611,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
class FederationRoomHierarchyServlet(BaseFederationServlet):
- PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
PATH = "/hierarchy/(?P<room_id>[^/]*)"
def __init__(
@@ -637,6 +636,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
)
+class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet):
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+
+
class RoomComplexityServlet(BaseFederationServlet):
"""
Indicates to other servers how complex (and therefore likely
@@ -701,6 +704,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
RoomComplexityServlet,
FederationSpaceSummaryServlet,
FederationRoomHierarchyServlet,
+ FederationRoomHierarchyUnstableServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
)
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 53f99031b1..a87896e538 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json
+from twisted.internet.defer import Deferred
+
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict, get_domain_from_id
@@ -166,7 +168,7 @@ class GroupAttestionRenewer:
return {}
- def _start_renew_attestations(self) -> None:
+ def _start_renew_attestations(self) -> "Deferred[None]":
return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self) -> None:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b62e13b725..4d9c4e5834 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
import unicodedata
import urllib.parse
from binascii import crc32
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -756,53 +757,109 @@ class AuthHandler:
async def refresh_token(
self,
refresh_token: str,
- valid_until_ms: Optional[int],
- ) -> Tuple[str, str]:
+ access_token_valid_until_ms: Optional[int],
+ refresh_token_valid_until_ms: Optional[int],
+ ) -> Tuple[str, str, Optional[int]]:
"""
        Consumes a refresh token and generates both a new access token and a new refresh token from it.
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+ The lifetime of both the access token and refresh token will be capped so that they
+ do not exceed the session's ultimate expiry time, if applicable.
+
Args:
refresh_token: The token to consume.
- valid_until_ms: The expiration timestamp of the new access token.
-
+ access_token_valid_until_ms: The expiration timestamp of the new access token.
+ None if the access token does not expire.
+ refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+ None if the refresh token does not expire.
Returns:
- A tuple containing the new access token and refresh token
+ A tuple containing:
+ - the new access token
+ - the new refresh token
+ - the actual expiry time of the access token, which may be earlier than
+ `access_token_valid_until_ms`.
"""
# Verify the token signature first before looking up the token
if not self._verify_refresh_token(refresh_token):
- raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+ raise SynapseError(
+ HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+ )
existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
- raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+ raise SynapseError(
+ HTTPStatus.UNAUTHORIZED,
+ "refresh token does not exist",
+ Codes.UNKNOWN_TOKEN,
+ )
if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(
- 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+ HTTPStatus.FORBIDDEN,
+ "refresh token isn't valid anymore",
+ Codes.FORBIDDEN,
+ )
+
+ now_ms = self._clock.time_msec()
+
+        if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+            raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "The supplied refresh token has expired",
+ Codes.FORBIDDEN,
)
+ if existing_token.ultimate_session_expiry_ts is not None:
+ # This session has a bounded lifetime, even across refreshes.
+
+ if access_token_valid_until_ms is not None:
+ access_token_valid_until_ms = min(
+ access_token_valid_until_ms,
+ existing_token.ultimate_session_expiry_ts,
+ )
+ else:
+ access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+ if refresh_token_valid_until_ms is not None:
+ refresh_token_valid_until_ms = min(
+ refresh_token_valid_until_ms,
+ existing_token.ultimate_session_expiry_ts,
+ )
+ else:
+ refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+ if existing_token.ultimate_session_expiry_ts < now_ms:
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "The session has expired and can no longer be refreshed",
+ Codes.FORBIDDEN,
+ )
+
(
new_refresh_token,
new_refresh_token_id,
- ) = await self.get_refresh_token_for_user_id(
- user_id=existing_token.user_id, device_id=existing_token.device_id
+ ) = await self.create_refresh_token_for_user_id(
+ user_id=existing_token.user_id,
+ device_id=existing_token.device_id,
+ expiry_ts=refresh_token_valid_until_ms,
+ ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
)
- access_token = await self.get_access_token_for_user_id(
+ access_token = await self.create_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
- valid_until_ms=valid_until_ms,
+ valid_until_ms=access_token_valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
- return access_token, new_refresh_token
+ return access_token, new_refresh_token, access_token_valid_until_ms
def _verify_refresh_token(self, token: str) -> bool:
"""
@@ -832,10 +889,12 @@ class AuthHandler:
return True
- async def get_refresh_token_for_user_id(
+ async def create_refresh_token_for_user_id(
self,
user_id: str,
device_id: str,
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> Tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@@ -843,6 +902,13 @@ class AuthHandler:
Args:
user_id: canonical user ID
device_id: the device ID to associate with the token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+                If None, the refresh token does not expire until it has been used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and cannot be extended any
+                further.
+ If None, the session can be refreshed indefinitely.
Returns:
The newly created refresh token and its ID in the database
@@ -852,10 +918,12 @@ class AuthHandler:
user_id=user_id,
token=refresh_token,
device_id=device_id,
+ expiry_ts=expiry_ts,
+ ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
return refresh_token, refresh_token_id
- async def get_access_token_for_user_id(
+ async def create_access_token_for_user_id(
self,
user_id: str,
device_id: Optional[str],
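A minimal standalone sketch (not Synapse code; the names are illustrative) of the capping rule described in the docstring above: both requested expiries are clamped so that neither token can outlive the session's ultimate expiry, and a missing expiry is treated as "expires with the session":

from typing import Optional, Tuple

def cap_token_expiries(
    access_token_valid_until_ms: Optional[int],
    refresh_token_valid_until_ms: Optional[int],
    ultimate_session_expiry_ts: Optional[int],
) -> Tuple[Optional[int], Optional[int]]:
    """Clamp both expiries to the session's ultimate expiry, if there is one."""
    if ultimate_session_expiry_ts is None:
        # Unbounded session: requested expiries pass through unchanged.
        return access_token_valid_until_ms, refresh_token_valid_until_ms

    def cap(requested: Optional[int]) -> int:
        # "No expiry" still becomes the session expiry.
        if requested is None:
            return ultimate_session_expiry_ts
        return min(requested, ultimate_session_expiry_ts)

    return cap(access_token_valid_until_ms), cap(refresh_token_valid_until_ms)

# Example: a 5 minute access token and a 1 day refresh token, but the session
# itself ends in 10 minutes, so the refresh token is capped to 10 minutes.
now_ms = 1_000_000
print(cap_token_expiries(now_ms + 300_000, now_ms + 86_400_000, now_ms + 600_000))
# -> (1300000, 1600000)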
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1f64534a8a..b4ff935546 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -124,7 +124,7 @@ class EventStreamHandler:
as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
- bundle_aggregations=False,
+ bundle_relations=False,
)
chunk = {
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 3dbe611f95..c83eaea359 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -464,15 +464,6 @@ class IdentityHandler:
if next_link:
params["next_link"] = next_link
- if self.hs.config.email.using_identity_server_from_trusted_list:
- # Warn that a deprecated config option is in use
- logger.warning(
- 'The config option "trust_identity_server_for_password_resets" '
- 'has been replaced by "account_threepid_delegate". '
- "Please consult the sample config at docs/sample_config.yaml for "
- "details and update your config file."
- )
-
try:
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
@@ -517,15 +508,6 @@ class IdentityHandler:
if next_link:
params["next_link"] = next_link
- if self.hs.config.email.using_identity_server_from_trusted_list:
- # Warn that a deprecated config option is in use
- logger.warning(
- 'The config option "trust_identity_server_for_password_resets" '
- 'has been replaced by "account_threepid_delegate". '
- "Please consult the sample config at docs/sample_config.yaml for "
- "details and update your config file."
- )
-
try:
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d4c2a6ab7a..95b4fad3c6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -252,7 +252,7 @@ class MessageHandler:
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
- bundle_aggregations=False,
+ bundle_relations=False,
)
return events
@@ -1001,13 +1001,52 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
+ await self._validate_event_relation(event)
+ logger.debug("Created event %s", event.event_id)
+
+ return event, context
+
+ async def _validate_event_relation(self, event: EventBase) -> None:
+ """
+ Ensure the relation data on a new event is not bogus.
+
+ Args:
+ event: The event being created.
+
+ Raises:
+ SynapseError if the event is invalid.
+ """
+
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ return
+
+ relation_type = relation.get("rel_type")
+ if not relation_type:
+ return
+
+ # Ensure the parent is real.
+ relates_to = relation.get("event_id")
+ if not relates_to:
+ return
+
+ parent_event = await self.store.get_event(relates_to, allow_none=True)
+ if parent_event:
+ # And in the same room.
+ if parent_event.room_id != event.room_id:
+ raise SynapseError(400, "Relations must be in the same room")
+
+ else:
+            # There must be some reason the client knows the event exists;
+            # check whether there are existing relations. If so, assume everything is fine.
+ if not await self.store.event_is_target_of_relation(relates_to):
+ # Otherwise, the client can't know about the parent event!
+ raise SynapseError(400, "Can't send relation to unknown event")
        # If this event is an annotation then we check that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
- relation = event.content.get("m.relates_to", {})
- if relation.get("rel_type") == RelationTypes.ANNOTATION:
- relates_to = relation["event_id"]
+ if relation_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"]
already_exists = await self.store.has_user_annotated_event(
@@ -1016,9 +1055,12 @@ class EventCreationHandler:
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
- logger.debug("Created event %s", event.event_id)
-
- return event, context
+ # Don't attempt to start a thread if the parent event is a relation.
+ elif relation_type == RelationTypes.THREAD:
+ if await self.store.event_includes_relation(relates_to):
+ raise SynapseError(
+ 400, "Cannot start threads from an event with a relation"
+ )
@measure_func("handle_new_client_event")
async def handle_new_client_event(
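For illustration (hypothetical event content, not taken from the patch), this is the shape of the `m.relates_to` data that `_validate_event_relation` inspects, together with the outcomes the checks above produce:

# An annotation (reaction) relating to a parent event:
annotation_content = {
    "m.relates_to": {
        "rel_type": "m.annotation",   # RelationTypes.ANNOTATION
        "event_id": "$parent_event_id",
        "key": "👍",
    },
}
# With the validation above, sending an event with this content is rejected if:
#  * "$parent_event_id" belongs to a different room,
#  * the parent is unknown to the server and nothing else relates to it, or
#  * the same sender has already annotated the parent with the same key.
# Similarly, a THREAD relation is rejected if the parent event itself carries
# a relation.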
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a0e6a01775..24ca11b924 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -116,7 +116,10 @@ class RegistrationHandler:
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.registration.session_lifetime
- self.access_token_lifetime = hs.config.registration.access_token_lifetime
+ self.refreshable_access_token_lifetime = (
+ hs.config.registration.refreshable_access_token_lifetime
+ )
+ self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
init_counters_for_auth_provider("")
@@ -791,13 +794,13 @@ class RegistrationHandler:
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker.worker_app
- valid_until_ms = None
+ access_token_expiry = None
if self.session_lifetime is not None:
if is_guest:
raise Exception(
"session_lifetime is not currently implemented for guest access"
)
- valid_until_ms = self.clock.time_msec() + self.session_lifetime
+ access_token_expiry = self.clock.time_msec() + self.session_lifetime
refresh_token = None
refresh_token_id = None
@@ -806,23 +809,57 @@ class RegistrationHandler:
user_id, device_id, initial_display_name
)
if is_guest:
- assert valid_until_ms is None
+ assert access_token_expiry is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
if should_issue_refresh_token:
+ # A refreshable access token lifetime must be configured
+ # since we're told to issue a refresh token (the caller checks
+ # that this value is set before setting this flag).
+ assert self.refreshable_access_token_lifetime is not None
+
+ now_ms = self.clock.time_msec()
+
+ # Set the expiry time of the refreshable access token
+ access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+ # Set the refresh token expiry time (if configured)
+ refresh_token_expiry = None
+ if self.refresh_token_lifetime is not None:
+ refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+ # Set an ultimate session expiry time (if configured)
+ ultimate_session_expiry_ts = None
+ if self.session_lifetime is not None:
+ ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+ # Also ensure that the issued tokens don't outlive the
+ # session.
+ # (It would be weird to configure a homeserver with a shorter
+ # session lifetime than token lifetime, but may as well handle
+ # it.)
+ access_token_expiry = min(
+ access_token_expiry, ultimate_session_expiry_ts
+ )
+ if refresh_token_expiry is not None:
+ refresh_token_expiry = min(
+ refresh_token_expiry, ultimate_session_expiry_ts
+ )
+
(
refresh_token,
refresh_token_id,
- ) = await self._auth_handler.get_refresh_token_for_user_id(
+ ) = await self._auth_handler.create_refresh_token_for_user_id(
user_id,
device_id=registered_device_id,
+ expiry_ts=refresh_token_expiry,
+ ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
- valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
- access_token = await self._auth_handler.get_access_token_for_user_id(
+ access_token = await self._auth_handler.create_access_token_for_user_id(
user_id,
device_id=registered_device_id,
- valid_until_ms=valid_until_ms,
+ valid_until_ms=access_token_expiry,
is_appservice_ghost=is_appservice_ghost,
refresh_token_id=refresh_token_id,
)
@@ -830,7 +867,7 @@ class RegistrationHandler:
return {
"device_id": registered_device_id,
"access_token": access_token,
- "valid_until_ms": valid_until_ms,
+ "valid_until_ms": access_token_expiry,
"refresh_token": refresh_token,
}
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f9a099c4f3..88053f9869 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -775,8 +775,11 @@ class RoomCreationHandler:
raise SynapseError(403, "Room visibility value not allowed.")
if is_public:
+ room_aliases = []
+ if room_alias:
+ room_aliases.append(room_alias.to_string())
if not self.config.roomdirectory.is_publishing_room_allowed(
- user_id, room_id, room_alias
+ user_id, room_id, room_aliases
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 0723286383..f880aa93d2 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -221,6 +221,7 @@ class RoomBatchHandler:
action=membership,
content=event_dict["content"],
outlier=True,
+ historical=True,
prev_event_ids=[prev_event_id_for_state_chain],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
@@ -240,6 +241,7 @@ class RoomBatchHandler:
),
event_dict,
outlier=True,
+ historical=True,
prev_event_ids=[prev_event_id_for_state_chain],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 08244b690d..a6dbff637f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -268,6 +268,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: Optional[dict] = None,
require_consent: bool = True,
outlier: bool = False,
+ historical: bool = False,
) -> Tuple[str, int]:
"""
Internal membership update function to get an existing event or create
@@ -293,6 +294,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
Returns:
Tuple of event ID and stream ordering position
@@ -337,6 +341,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
auth_event_ids=auth_event_ids,
require_consent=require_consent,
outlier=outlier,
+ historical=historical,
)
prev_state_ids = await context.get_prev_state_ids()
@@ -433,6 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
new_room: bool = False,
require_consent: bool = True,
outlier: bool = False,
+ historical: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
@@ -454,6 +460,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
prev_event_ids: The event IDs to use as the prev events
auth_event_ids:
The event ids to use as the auth_events for the new event.
@@ -487,6 +496,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
new_room=new_room,
require_consent=require_consent,
outlier=outlier,
+ historical=historical,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
)
@@ -507,6 +517,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
new_room: bool = False,
require_consent: bool = True,
outlier: bool = False,
+ historical: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
@@ -530,6 +541,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
prev_event_ids: The event IDs to use as the prev events
auth_event_ids:
The event ids to use as the auth_events for the new event.
@@ -657,6 +671,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content=content,
require_consent=require_consent,
outlier=outlier,
+ historical=historical,
)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index d9764a7797..c06939e3ca 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,8 +36,9 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -93,11 +94,14 @@ class RoomSummaryHandler:
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
+ self._ratelimiter = Ratelimiter(
+ store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+ )
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[
- Tuple[str, bool, Optional[int], Optional[int], Optional[str]]
+ Tuple[str, str, bool, Optional[int], Optional[int], Optional[str]]
] = ResponseCache(
hs.get_clock(),
"get_room_hierarchy",
@@ -249,7 +253,7 @@ class RoomSummaryHandler:
async def get_room_hierarchy(
self,
- requester: str,
+ requester: Requester,
requested_room_id: str,
suggested_only: bool = False,
max_depth: Optional[int] = None,
@@ -276,15 +280,24 @@ class RoomSummaryHandler:
Returns:
The JSON hierarchy dictionary.
"""
+ await self._ratelimiter.ratelimit(requester)
+
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
#
# This is due to the pagination process mutating internal state, attempting
# to process multiple requests for the same page will result in errors.
return await self._pagination_response_cache.wrap(
- (requested_room_id, suggested_only, max_depth, limit, from_token),
+ (
+ requester.user.to_string(),
+ requested_room_id,
+ suggested_only,
+ max_depth,
+ limit,
+ from_token,
+ ),
self._get_room_hierarchy,
- requester,
+ requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,
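A rough standalone sketch (toy code, not the ResponseCache implementation) of why the requester is now part of the pagination-cache key: concurrent identical requests from the same user are collapsed into one, while different users no longer share cached pagination state (or, with the new ratelimiter, each other's budget):

import asyncio
from typing import Any, Awaitable, Callable, Dict, Tuple

class KeyedDeduplicator:
    """Collapse concurrent calls that share a key (a toy stand-in for ResponseCache)."""

    def __init__(self) -> None:
        self._in_flight: Dict[Tuple[Any, ...], "asyncio.Task[Any]"] = {}

    async def wrap(
        self, key: Tuple[Any, ...], func: Callable[..., Awaitable[Any]], *args: Any
    ) -> Any:
        task = self._in_flight.get(key)
        if task is None:
            task = asyncio.ensure_future(func(*args))
            self._in_flight[key] = task
            # Forget the entry once the request finishes.
            task.add_done_callback(lambda _: self._in_flight.pop(key, None))
        return await task

async def fetch_hierarchy(user_id: str, room_id: str) -> str:
    await asyncio.sleep(0.01)  # pretend to paginate over federation
    return f"hierarchy of {room_id} as seen by {user_id}"

async def main() -> None:
    cache = KeyedDeduplicator()
    key_a = ("@a:example.org", "!room:example.org")
    key_b = ("@b:example.org", "!room:example.org")
    results = await asyncio.gather(
        # Same user, same room: the second call reuses the in-flight request.
        cache.wrap(key_a, fetch_hierarchy, *key_a),
        cache.wrap(key_a, fetch_hierarchy, *key_a),
        # Different user: a separate request, because the key includes the user.
        cache.wrap(key_b, fetch_hierarchy, *key_b),
    )
    print(results)

asyncio.run(main())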
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 22c6174821..1676ebd057 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -90,7 +90,7 @@ class FollowerTypingHandler:
self.wheel_timer = WheelTimer(bucket_size=5000)
@wrap_as_background_process("typing._handle_timeouts")
- def _handle_timeouts(self) -> None:
+ async def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 91ee5c8193..ceef57ad88 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,10 +20,25 @@ import os
import platform
import threading
import time
-from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
-from prometheus_client import Counter, Gauge, Histogram
+from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
from prometheus_client.core import (
REGISTRY,
CounterMetricFamily,
@@ -32,6 +47,7 @@ from prometheus_client.core import (
)
from twisted.internet import reactor
+from twisted.internet.base import ReactorBase
from twisted.python.threadpool import ThreadPool
import synapse
@@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
class RegistryProxy:
@staticmethod
- def collect():
+ def collect() -> Iterable[Metric]:
for metric in REGISTRY.collect():
if not metric.name.startswith("__"):
yield metric
@@ -74,7 +90,7 @@ class LaterGauge:
]
)
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
@@ -93,10 +109,10 @@ class LaterGauge:
yield g
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
self._register()
- def _register(self):
+ def _register(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
@@ -105,7 +121,12 @@ class LaterGauge:
all_gauges[self.name] = self
-class InFlightGauge:
+# `MetricsEntry` only makes sense when it is a `Protocol`,
+# but `Protocol` can't be used as a `TypeVar` bound.
+MetricsEntry = TypeVar("MetricsEntry")
+
+
+class InFlightGauge(Generic[MetricsEntry]):
"""Tracks number of things (e.g. requests, Measure blocks, etc) in flight
at any given time.
@@ -115,14 +136,19 @@ class InFlightGauge:
callbacks.
Args:
- name (str)
- desc (str)
- labels (list[str])
- sub_metrics (list[str]): A list of sub metrics that the callbacks
- will update.
+ name
+ desc
+ labels
+ sub_metrics: A list of sub metrics that the callbacks will update.
"""
- def __init__(self, name, desc, labels, sub_metrics):
+ def __init__(
+ self,
+ name: str,
+ desc: str,
+ labels: Sequence[str],
+ sub_metrics: Sequence[str],
+ ):
self.name = name
self.desc = desc
self.labels = labels
@@ -130,19 +156,25 @@ class InFlightGauge:
        # Create a class which has the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
- self._metrics_class = attr.make_class(
+ self._metrics_class: Type[MetricsEntry] = attr.make_class(
"_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
)
# Counts number of in flight blocks for a given set of label values
- self._registrations: Dict = {}
+ self._registrations: Dict[
+ Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
+ ] = {}
# Protects access to _registrations
self._lock = threading.Lock()
self._register_with_collector()
- def register(self, key, callback):
+ def register(
+ self,
+ key: Tuple[str, ...],
+ callback: Callable[[MetricsEntry], None],
+ ) -> None:
"""Registers that we've entered a new block with labels `key`.
`callback` gets called each time the metrics are collected. The same
@@ -158,13 +190,17 @@ class InFlightGauge:
with self._lock:
self._registrations.setdefault(key, set()).add(callback)
- def unregister(self, key, callback):
+ def unregister(
+ self,
+ key: Tuple[str, ...],
+ callback: Callable[[MetricsEntry], None],
+ ) -> None:
"""Registers that we've exited a block with labels `key`."""
with self._lock:
self._registrations.setdefault(key, set()).discard(callback)
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
"""Called by prometheus client when it reads metrics.
Note: may be called by a separate thread.
@@ -200,7 +236,7 @@ class InFlightGauge:
gauge.add_metric(key, getattr(metrics, name))
yield gauge
- def _register_with_collector(self):
+ def _register_with_collector(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
@@ -230,7 +266,7 @@ class GaugeBucketCollector:
name: str,
documentation: str,
buckets: Iterable[float],
- registry=REGISTRY,
+ registry: CollectorRegistry = REGISTRY,
):
"""
Args:
@@ -257,12 +293,12 @@ class GaugeBucketCollector:
registry.register(self)
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
# Don't report metrics unless we've already collected some data
if self._metric is not None:
yield self._metric
- def update_data(self, values: Iterable[float]):
+ def update_data(self, values: Iterable[float]) -> None:
"""Update the data to be reported by the metric
The existing data is cleared, and each measurement in the input is assigned
@@ -304,7 +340,7 @@ class GaugeBucketCollector:
class CPUMetrics:
- def __init__(self):
+ def __init__(self) -> None:
ticks_per_sec = 100
try:
# Try and get the system config
@@ -314,7 +350,7 @@ class CPUMetrics:
self.ticks_per_sec = ticks_per_sec
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
if not HAVE_PROC_SELF_STAT:
return
@@ -364,7 +400,7 @@ gc_time = Histogram(
class GCCounts:
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
cm.add_metric([str(n)], m)
@@ -382,7 +418,7 @@ if not running_on_pypy:
class PyPyGCStats:
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
# @stats is a pretty-printer object with __str__() returning a nice table,
# plus some fields that contain data from that table.
@@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
class ReactorLastSeenMetric:
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily(
"python_twisted_reactor_last_seen",
"Seconds since the Twisted reactor was last seen",
@@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
_last_gc = [0.0, 0.0, 0.0]
-def runUntilCurrentTimer(reactor, func):
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
@functools.wraps(func)
- def f(*args, **kwargs):
+ def f(*args: Any, **kwargs: Any) -> Any:
now = reactor.seconds()
num_pending = 0
@@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func):
return ret
- return f
+ return cast(F, f)
try:
@@ -677,5 +716,5 @@ __all__ = [
"start_http_server",
"LaterGauge",
"InFlightGauge",
- "BucketCollector",
+ "GaugeBucketCollector",
]
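A hedged usage sketch of the now-generic `InFlightGauge` (the gauge name, label and sub-metric are made up; the pattern follows the `register`/`unregister` methods typed above, assuming Synapse is importable):

from synapse.metrics import InFlightGauge

# Hypothetical gauge tracking in-flight widget fetches, with one sub-metric.
widgets_in_flight = InFlightGauge(
    "myapp_widgets_in_flight",
    "Widget fetches currently in flight",
    labels=["widget_type"],
    sub_metrics=["elapsed_ms"],
)

def _report(metrics) -> None:
    # Called at scrape time; `metrics` is the attrs-generated _MetricsEntry,
    # with one attribute per sub-metric (here: elapsed_ms).
    metrics.elapsed_ms += 123

key = ("fancy",)                            # one value per label
widgets_in_flight.register(key, _report)    # entering the block
# ... do the work being measured ...
widgets_in_flight.unregister(key, _report)  # leaving the block

The `Generic[MetricsEntry]` parameter lets callers describe the shape of the callback argument, although, as the comment above notes, it cannot be expressed precisely because `Protocol` is not usable as a `TypeVar` bound.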
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index bb9bcb5592..353d0a63b6 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -25,27 +25,25 @@ import math
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
-from typing import Dict, List
+from typing import Any, Dict, List, Type, Union
from urllib.parse import parse_qs, urlparse
-from prometheus_client import REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry
+from prometheus_client.core import Sample
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.util import caches
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
-INF = float("inf")
-MINUS_INF = float("-inf")
-
-
-def floatToGoString(d):
+def floatToGoString(d: Union[int, float]) -> str:
d = float(d)
- if d == INF:
+ if d == math.inf:
return "+Inf"
- elif d == MINUS_INF:
+ elif d == -math.inf:
return "-Inf"
elif math.isnan(d):
return "NaN"
@@ -60,7 +58,7 @@ def floatToGoString(d):
return s
-def sample_line(line, name):
+def sample_line(line: Sample, name: str) -> str:
if line.labels:
labelstr = "{{{0}}}".format(
",".join(
@@ -82,7 +80,7 @@ def sample_line(line, name):
return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
-def generate_latest(registry, emit_help=False):
+def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
# Trigger the cache metrics to be rescraped, which updates the common
# metrics but do not produce metrics themselves
@@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
registry = REGISTRY
- def do_GET(self):
+ def do_GET(self) -> None:
registry = self.registry
params = parse_qs(urlparse(self.path).query)
@@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(output)
- def log_message(self, format, *args):
+ def log_message(self, format: str, *args: Any) -> None:
"""Log nothing."""
@classmethod
- def factory(cls, registry):
+ def factory(cls, registry: CollectorRegistry) -> Type:
"""Returns a dynamic MetricsHandler class tied
to the passed registry.
"""
@@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
daemon_threads = True
-def start_http_server(port, addr="", registry=REGISTRY):
+def start_http_server(
+ port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
+) -> None:
"""Starts an HTTP server for prometheus metrics as a daemon thread"""
CustomMetricsHandler = MetricsHandler.factory(registry)
httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
@@ -252,10 +252,10 @@ class MetricsResource(Resource):
isLeaf = True
- def __init__(self, registry=REGISTRY):
+ def __init__(self, registry: CollectorRegistry = REGISTRY):
self.registry = registry
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
response = generate_latest(self.registry)
request.setHeader(b"Content-Length", str(len(response)))
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 2ab599a334..53c508af91 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,19 +15,37 @@
import logging
import threading
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set, Union
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+ Set,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+from prometheus_client import Metric
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+ ContextResourceUsage,
+ LoggingContext,
+ PreserveLoggingContext,
+)
from synapse.logging.opentracing import (
SynapseTags,
noop_context_manager,
start_active_span,
)
-from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -116,7 +134,7 @@ class _Collector:
before they are returned.
"""
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
global _background_processes_active_since_last_scrape
# We swap out the _background_processes set with an empty one so that
@@ -144,12 +162,12 @@ REGISTRY.register(_Collector())
class _BackgroundProcess:
- def __init__(self, desc, ctx):
+ def __init__(self, desc: str, ctx: LoggingContext):
self.desc = desc
self._context = ctx
- self._reported_stats = None
+ self._reported_stats: Optional[ContextResourceUsage] = None
- def update_metrics(self):
+ def update_metrics(self) -> None:
"""Updates the metrics with values from this process."""
new_stats = self._context.get_resource_usage()
if self._reported_stats is None:
@@ -169,7 +187,16 @@ class _BackgroundProcess:
)
-def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
+R = TypeVar("R")
+
+
+def run_as_background_process(
+ desc: str,
+ func: Callable[..., Awaitable[Optional[R]]],
+ *args: Any,
+ bg_start_span: bool = True,
+ **kwargs: Any,
+) -> "defer.Deferred[Optional[R]]":
"""Run the given function in its own logcontext, with resource metrics
This should be used to wrap processes which are fired off to run in the
@@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
args: positional args for func
kwargs: keyword args for func
- Returns: Deferred which returns the result of func, but note that it does not
- follow the synapse logcontext rules.
+ Returns:
+ Deferred which returns the result of func, or `None` if func raises.
+ Note that the returned Deferred does not follow the synapse logcontext
+ rules.
"""
- async def run():
+ async def run() -> Optional[R]:
with _bg_metrics_lock:
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
@@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
else:
ctx = noop_context_manager()
with ctx:
- return await maybe_awaitable(func(*args, **kwargs))
+ return await func(*args, **kwargs)
except Exception:
logger.exception(
"Background process '%s' threw an exception",
desc,
)
+ return None
finally:
_background_process_in_flight_count.labels(desc).dec()
@@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
return defer.ensureDeferred(run())
-def wrap_as_background_process(desc):
+F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]])
+
+
+def wrap_as_background_process(desc: str) -> Callable[[F], F]:
"""Decorator that wraps a function that gets called as a background
process.
- Equivalent of calling the function with `run_as_background_process`
+ Equivalent to calling the function with `run_as_background_process`
"""
- def wrap_as_background_process_inner(func):
+ def wrap_as_background_process_inner(func: F) -> F:
@wraps(func)
- def wrap_as_background_process_inner_2(*args, **kwargs):
+ def wrap_as_background_process_inner_2(
+ *args: Any, **kwargs: Any
+ ) -> "defer.Deferred[Optional[R]]":
return run_as_background_process(desc, func, *args, **kwargs)
- return wrap_as_background_process_inner_2
+ return cast(F, wrap_as_background_process_inner_2)
return wrap_as_background_process_inner
@@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__init__("%s-%s" % (name, instance_id))
self._proc = _BackgroundProcess(name, self)
- def start(self, rusage: "Optional[resource.struct_rusage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""Log context has started running (again)."""
super().start(rusage)
@@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext):
with _bg_metrics_lock:
_background_processes_active_since_last_scrape.add(self._proc)
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
"""Log context has finished."""
super().__exit__(type, value, traceback)
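A hedged usage sketch (hypothetical task name) of the two entry points typed above. With `maybe_awaitable` gone, the wrapped function must now be a proper coroutine function, which is why callers such as `handle_sighup` and `_handle_timeouts` elsewhere in this patch became `async def`; the returned Deferred resolves to `None` if the function raises:

from synapse.metrics.background_process_metrics import (
    run_as_background_process,
    wrap_as_background_process,
)

async def _prune_old_widgets(batch_size: int) -> int:
    # ... periodic clean-up work would go here ...
    return batch_size

# Fire-and-forget: runs in its own logcontext and records resource metrics.
# The Deferred resolves with the coroutine's result, or None on failure.
deferred = run_as_background_process("prune_old_widgets", _prune_old_widgets, 100)

@wrap_as_background_process("prune_old_widgets")
async def prune_old_widgets(batch_size: int) -> int:
    return batch_size

# Calling the decorated function returns a Deferred rather than a coroutine.
deferred_2 = prune_old_widgets(100)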
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 29ab6c0229..98ed9c0829 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -16,14 +16,16 @@ import ctypes
import logging
import os
import re
-from typing import Optional
+from typing import Iterable, Optional
+
+from prometheus_client import Metric
from synapse.metrics import REGISTRY, GaugeMetricFamily
logger = logging.getLogger(__name__)
-def _setup_jemalloc_stats():
+def _setup_jemalloc_stats() -> None:
"""Checks to see if jemalloc is loaded, and hooks up a collector to record
statistics exposed by jemalloc.
"""
@@ -135,7 +137,7 @@ def _setup_jemalloc_stats():
class JemallocCollector:
"""Metrics for internal jemalloc stats."""
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
_jemalloc_refresh_stats()
g = GaugeMetricFamily(
@@ -185,7 +187,7 @@ def _setup_jemalloc_stats():
logger.debug("Added jemalloc stats")
-def setup_jemalloc_stats():
+def setup_jemalloc_stats() -> None:
"""Try to setup jemalloc stats, if jemalloc is loaded."""
try:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ff79bc3c11..a8154168be 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,6 +24,7 @@ from typing import (
List,
Optional,
Tuple,
+ TypeVar,
Union,
)
@@ -35,7 +36,44 @@ from twisted.web.resource import Resource
from synapse.api.errors import SynapseError
from synapse.events import EventBase
-from synapse.events.presence_router import PresenceRouter
+from synapse.events.presence_router import (
+ GET_INTERESTED_USERS_CALLBACK,
+ GET_USERS_FOR_STATES_CALLBACK,
+ PresenceRouter,
+)
+from synapse.events.spamcheck import (
+ CHECK_EVENT_FOR_SPAM_CALLBACK,
+ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
+ CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
+ CHECK_USERNAME_FOR_SPAM_CALLBACK,
+ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK,
+ USER_MAY_CREATE_ROOM_CALLBACK,
+ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK,
+ USER_MAY_INVITE_CALLBACK,
+ USER_MAY_JOIN_ROOM_CALLBACK,
+ USER_MAY_PUBLISH_ROOM_CALLBACK,
+ USER_MAY_SEND_3PID_INVITE_CALLBACK,
+)
+from synapse.events.third_party_rules import (
+ CHECK_EVENT_ALLOWED_CALLBACK,
+ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK,
+ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK,
+ ON_CREATE_ROOM_CALLBACK,
+ ON_NEW_EVENT_CALLBACK,
+)
+from synapse.handlers.account_validity import (
+ IS_USER_EXPIRED_CALLBACK,
+ ON_LEGACY_ADMIN_REQUEST,
+ ON_LEGACY_RENEW_CALLBACK,
+ ON_LEGACY_SEND_MAIL_CALLBACK,
+ ON_USER_REGISTRATION_CALLBACK,
+)
+from synapse.handlers.auth import (
+ CHECK_3PID_AUTH_CALLBACK,
+ CHECK_AUTH_CALLBACK,
+ ON_LOGGED_OUT_CALLBACK,
+ AuthHandler,
+)
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@@ -44,10 +82,19 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+ defer_to_thread,
+ make_deferred_yieldable,
+ run_in_background,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
+from synapse.storage.background_updates import (
+ DEFAULT_BATCH_SIZE_CALLBACK,
+ MIN_BATCH_SIZE_CALLBACK,
+ ON_UPDATE_CALLBACK,
+)
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
@@ -67,6 +114,9 @@ if TYPE_CHECKING:
from synapse.app.generic_worker import GenericWorkerSlavedStore
from synapse.server import HomeServer
+
+T = TypeVar("T")
+
"""
This package defines the 'stable' API which can be used by extension modules which
are loaded into Synapse.
@@ -114,7 +164,7 @@ class ModuleApi:
can register new users etc if necessary.
"""
- def __init__(self, hs: "HomeServer", auth_handler):
+ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
self._hs = hs
# TODO: Fix this type hint once the types for the data stores have been ironed
@@ -156,47 +206,139 @@ class ModuleApi:
#################################################################################
# The following methods should only be called during the module's initialisation.
- @property
- def register_spam_checker_callbacks(self):
+ def register_spam_checker_callbacks(
+ self,
+ check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+ user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
+ user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
+ user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
+ user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+ user_may_create_room_with_invites: Optional[
+ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+ ] = None,
+ user_may_create_room_alias: Optional[
+ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
+ ] = None,
+ user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
+ check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
+ check_registration_for_spam: Optional[
+ CHECK_REGISTRATION_FOR_SPAM_CALLBACK
+ ] = None,
+ check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ ) -> None:
"""Registers callbacks for spam checking capabilities.
Added in Synapse v1.37.0.
"""
- return self._spam_checker.register_callbacks
+ return self._spam_checker.register_callbacks(
+ check_event_for_spam=check_event_for_spam,
+ user_may_join_room=user_may_join_room,
+ user_may_invite=user_may_invite,
+ user_may_send_3pid_invite=user_may_send_3pid_invite,
+ user_may_create_room=user_may_create_room,
+ user_may_create_room_with_invites=user_may_create_room_with_invites,
+ user_may_create_room_alias=user_may_create_room_alias,
+ user_may_publish_room=user_may_publish_room,
+ check_username_for_spam=check_username_for_spam,
+ check_registration_for_spam=check_registration_for_spam,
+ check_media_file_for_spam=check_media_file_for_spam,
+ )
- @property
- def register_account_validity_callbacks(self):
+ def register_account_validity_callbacks(
+ self,
+ is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+ on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+ on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+ on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+ on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+ ) -> None:
"""Registers callbacks for account validity capabilities.
Added in Synapse v1.39.0.
"""
- return self._account_validity_handler.register_account_validity_callbacks
+ return self._account_validity_handler.register_account_validity_callbacks(
+ is_user_expired=is_user_expired,
+ on_user_registration=on_user_registration,
+ on_legacy_send_mail=on_legacy_send_mail,
+ on_legacy_renew=on_legacy_renew,
+ on_legacy_admin_request=on_legacy_admin_request,
+ )
- @property
- def register_third_party_rules_callbacks(self):
+ def register_third_party_rules_callbacks(
+ self,
+ check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+ on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
+ check_threepid_can_be_invited: Optional[
+ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+ ] = None,
+ check_visibility_can_be_modified: Optional[
+ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+ ] = None,
+ on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
+ ) -> None:
"""Registers callbacks for third party event rules capabilities.
Added in Synapse v1.39.0.
"""
- return self._third_party_event_rules.register_third_party_rules_callbacks
+ return self._third_party_event_rules.register_third_party_rules_callbacks(
+ check_event_allowed=check_event_allowed,
+ on_create_room=on_create_room,
+ check_threepid_can_be_invited=check_threepid_can_be_invited,
+ check_visibility_can_be_modified=check_visibility_can_be_modified,
+ on_new_event=on_new_event,
+ )
- @property
- def register_presence_router_callbacks(self):
+ def register_presence_router_callbacks(
+ self,
+ get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
+ get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
+ ) -> None:
"""Registers callbacks for presence router capabilities.
Added in Synapse v1.42.0.
"""
- return self._presence_router.register_presence_router_callbacks
+ return self._presence_router.register_presence_router_callbacks(
+ get_users_for_states=get_users_for_states,
+ get_interested_users=get_interested_users,
+ )
- @property
- def register_password_auth_provider_callbacks(self):
+ def register_password_auth_provider_callbacks(
+ self,
+ check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+ on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+ auth_checkers: Optional[
+ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
+ ] = None,
+ ) -> None:
"""Registers callbacks for password auth provider capabilities.
Added in Synapse v1.46.0.
"""
- return self._password_auth_provider.register_password_auth_provider_callbacks
+ return self._password_auth_provider.register_password_auth_provider_callbacks(
+ check_3pid_auth=check_3pid_auth,
+ on_logged_out=on_logged_out,
+ auth_checkers=auth_checkers,
+ )
+
+ def register_background_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Registers background update controller callbacks.
- def register_web_resource(self, path: str, resource: Resource):
+ Added in Synapse v1.49.0.
+ """
+
+ for db in self._hs.get_datastores().databases:
+ db.updates.register_update_controller_callbacks(
+ on_update=on_update,
+ default_batch_size=default_batch_size,
+ min_batch_size=min_batch_size,
+ )
+
+ def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path.
This function should be called during initialisation of the module.
@@ -216,7 +358,7 @@ class ModuleApi:
# The following methods can be called by the module at any point in time.
@property
- def http_client(self):
+ def http_client(self) -> SimpleHttpClient:
"""Allows making outbound HTTP requests to remote resources.
An instance of synapse.http.client.SimpleHttpClient
@@ -226,7 +368,7 @@ class ModuleApi:
return self._http_client
@property
- def public_room_list_manager(self):
+ def public_room_list_manager(self) -> "PublicRoomListManager":
"""Allows adding to, removing from and checking the status of rooms in the
public room list.
@@ -309,7 +451,7 @@ class ModuleApi:
"""
return await self._store.is_server_admin(UserID.from_string(user_id))
- def get_qualified_user_id(self, username):
+ def get_qualified_user_id(self, username: str) -> str:
"""Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
@@ -318,10 +460,10 @@ class ModuleApi:
Added in Synapse v0.25.0.
Args:
- username (str): provided user id
+ username: provided user id
Returns:
- str: qualified @user:id
+ qualified @user:id
"""
if username.startswith("@"):
return username
@@ -357,22 +499,27 @@ class ModuleApi:
"""
return await self._store.user_get_threepids(user_id)
- def check_user_exists(self, user_id):
+ def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
"""Check if user exists.
Added in Synapse v0.25.0.
Args:
- user_id (str): Complete @user:id
+ user_id: Complete @user:id
Returns:
- Deferred[str|None]: Canonical (case-corrected) user_id, or None
+ Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks
- def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
+ def register(
+ self,
+ localpart: str,
+ displayname: Optional[str] = None,
+ emails: Optional[List[str]] = None,
+ ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]:
"""Registers a new user with given localpart and optional displayname, emails.
Also returns an access token for the new user.
@@ -384,12 +531,12 @@ class ModuleApi:
Added in Synapse v0.25.0.
Args:
- localpart (str): The localpart of the new user.
- displayname (str|None): The displayname of the new user.
- emails (List[str]): Emails to bind to the new user.
+ localpart: The localpart of the new user.
+ displayname: The displayname of the new user.
+ emails: Emails to bind to the new user.
Returns:
- Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
+ a 2-tuple of (user_id, access_token)
"""
logger.warning(
"Using deprecated ModuleApi.register which creates a dummy user device."
@@ -399,23 +546,26 @@ class ModuleApi:
return user_id, access_token
def register_user(
- self, localpart, displayname=None, emails: Optional[List[str]] = None
- ):
+ self,
+ localpart: str,
+ displayname: Optional[str] = None,
+ emails: Optional[List[str]] = None,
+ ) -> "defer.Deferred[str]":
"""Registers a new user with given localpart and optional displayname, emails.
Added in Synapse v1.2.0.
Args:
- localpart (str): The localpart of the new user.
- displayname (str|None): The displayname of the new user.
- emails (List[str]): Emails to bind to the new user.
+ localpart: The localpart of the new user.
+ displayname: The displayname of the new user.
+ emails: Emails to bind to the new user.
Raises:
SynapseError if there is an error performing the registration. Check the
'errcode' property for more information on the reason for failure
Returns:
- defer.Deferred[str]: user_id
+ user_id
"""
return defer.ensureDeferred(
self._hs.get_registration_handler().register_user(
@@ -425,20 +575,25 @@ class ModuleApi:
)
)
- def register_device(self, user_id, device_id=None, initial_display_name=None):
+ def register_device(
+ self,
+ user_id: str,
+ device_id: Optional[str] = None,
+ initial_display_name: Optional[str] = None,
+ ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]":
"""Register a device for a user and generate an access token.
Added in Synapse v1.2.0.
Args:
- user_id (str): full canonical @user:id
- device_id (str|None): The device ID to check, or None to generate
+ user_id: full canonical @user:id
+ device_id: The device ID to check, or None to generate
a new one.
- initial_display_name (str|None): An optional display name for the
+ initial_display_name: An optional display name for the
device.
Returns:
- defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+            Tuple of device ID, access token, access token expiration time, and refresh token.
"""
return defer.ensureDeferred(
self._hs.get_registration_handler().register_device(
@@ -492,7 +647,9 @@ class ModuleApi:
)
@defer.inlineCallbacks
- def invalidate_access_token(self, access_token):
+ def invalidate_access_token(
+ self, access_token: str
+ ) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user
Added in Synapse v0.25.0.
@@ -524,14 +681,20 @@ class ModuleApi:
self._auth_handler.delete_access_token(access_token)
)
- def run_db_interaction(self, desc, func, *args, **kwargs):
+ def run_db_interaction(
+ self,
+ desc: str,
+ func: Callable[..., T],
+ *args: Any,
+ **kwargs: Any,
+ ) -> "defer.Deferred[T]":
"""Run a function with a database connection
Added in Synapse v0.25.0.
Args:
- desc (str): description for the transaction, for metrics etc
- func (func): function to be run. Passed a database cursor object
+ desc: description for the transaction, for metrics etc
+ func: function to be run. Passed a database cursor object
as well as *args and **kwargs
*args: positional args to be passed to func
**kwargs: named args to be passed to func
@@ -545,7 +708,7 @@ class ModuleApi:
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
- ):
+ ) -> None:
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -575,7 +738,7 @@ class ModuleApi:
client_redirect_url: str,
new_user: bool = False,
auth_provider_id: str = "<unknown>",
- ):
+ ) -> None:
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -814,11 +977,11 @@ class ModuleApi:
self,
f: Callable,
msec: float,
- *args,
+ *args: object,
desc: Optional[str] = None,
run_on_all_instances: bool = False,
- **kwargs,
- ):
+ **kwargs: object,
+ ) -> None:
"""Wraps a function as a background process and calls it repeatedly.
NOTE: Will only run on the instance that is configured to run
@@ -859,13 +1022,18 @@ class ModuleApi:
f,
)
+ async def sleep(self, seconds: float) -> None:
+ """Sleeps for the given number of seconds."""
+
+ await self._clock.sleep(seconds)
+
async def send_mail(
self,
recipient: str,
subject: str,
html: str,
text: str,
- ):
+ ) -> None:
"""Send an email on behalf of the homeserver.
Added in Synapse v1.39.0.
@@ -903,7 +1071,7 @@ class ModuleApi:
A list containing the loaded templates, with the orders matching the one of
the filenames parameter.
"""
- return self._hs.config.read_templates(
+ return self._hs.config.server.read_templates(
filenames,
(td for td in (self.custom_template_dir, custom_template_directory) if td),
)
@@ -1013,6 +1181,26 @@ class ModuleApi:
return {key: state_events[event_id] for key, event_id in state_ids.items()}
+ async def defer_to_thread(
+ self,
+ f: Callable[..., T],
+ *args: Any,
+ **kwargs: Any,
+ ) -> T:
+ """Runs the given function in a separate thread from Synapse's thread pool.
+
+ Added in Synapse v1.49.0.
+
+ Args:
+ f: The function to run.
+ args: The function's arguments.
+ kwargs: The function's keyword arguments.
+
+ Returns:
+            The return value of the function once run in a thread.
+ """
+ return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 154e5b7028..7d26954244 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -86,7 +86,7 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
"cryptography>=3.4.7",
- "ijson>=3.0",
+ "ijson>=3.1",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
from typing import List, Optional, Tuple
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+ """Tracks the "current" stream ID of a stream with a single writer.
+
+ See `AbstractStreamIdTracker` for more details.
+
+ Note that this class does not work correctly when there are multiple
+ writers.
+ """
+
def __init__(
self,
db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:
- """
-
- Returns:
- int
- """
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- # We assert this for the benefit of mypy
- assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
import attr
@@ -157,7 +157,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
- event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+ event_rows = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
)
@@ -191,7 +191,7 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
- ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+ ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 36cfd1e4e2..d7621981d9 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -28,6 +28,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.background_updates import (
BackgroundUpdateEnabledRestServlet,
BackgroundUpdateRestServlet,
+ BackgroundUpdateStartJobRestServlet,
)
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
@@ -46,6 +47,7 @@ from synapse.rest.admin.registration_tokens import (
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import (
+ BlockRoomRestServlet,
DeleteRoomStatusByDeleteIdRestServlet,
DeleteRoomStatusByRoomIdRestServlet,
ForwardExtremitiesRestServlet,
@@ -224,6 +226,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
+ BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
@@ -261,6 +264,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SendServerNoticeServlet(hs).register(http_server)
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server)
+ BackgroundUpdateStartJobRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 0d0183bf20..479672d4d5 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.types import JsonDict
@@ -29,37 +34,36 @@ logger = logging.getLogger(__name__)
class BackgroundUpdateEnabledRestServlet(RestServlet):
"""Allows temporarily disabling background updates"""
- PATTERNS = admin_patterns("/background_updates/enabled")
+ PATTERNS = admin_patterns("/background_updates/enabled$")
def __init__(self, hs: "HomeServer"):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- self.data_stores = hs.get_datastores()
+ self._auth = hs.get_auth()
+ self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester.user)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
- enabled = all(db.updates.enabled for db in self.data_stores.databases)
+ enabled = all(db.updates.enabled for db in self._data_stores.databases)
- return 200, {"enabled": enabled}
+ return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester.user)
body = parse_json_object_from_request(request)
enabled = body.get("enabled", True)
if not isinstance(enabled, bool):
- raise SynapseError(400, "'enabled' parameter must be a boolean")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean"
+ )
- for db in self.data_stores.databases:
+ for db in self._data_stores.databases:
db.updates.enabled = enabled
# If we're re-enabling them ensure that we start the background
@@ -67,32 +71,29 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
if enabled:
db.updates.start_doing_background_updates()
- return 200, {"enabled": enabled}
+ return HTTPStatus.OK, {"enabled": enabled}
class BackgroundUpdateRestServlet(RestServlet):
"""Fetch information about background updates"""
- PATTERNS = admin_patterns("/background_updates/status")
+ PATTERNS = admin_patterns("/background_updates/status$")
def __init__(self, hs: "HomeServer"):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- self.data_stores = hs.get_datastores()
+ self._auth = hs.get_auth()
+ self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester.user)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
- enabled = all(db.updates.enabled for db in self.data_stores.databases)
+ enabled = all(db.updates.enabled for db in self._data_stores.databases)
current_updates = {}
- for db in self.data_stores.databases:
+ for db in self._data_stores.databases:
update = db.updates.get_current_update()
if not update:
continue
@@ -104,4 +105,72 @@ class BackgroundUpdateRestServlet(RestServlet):
"average_items_per_ms": update.average_items_per_ms(),
}
- return 200, {"enabled": enabled, "current_updates": current_updates}
+ return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates}
+
+
+class BackgroundUpdateStartJobRestServlet(RestServlet):
+ """Allows to start specific background updates"""
+
+ PATTERNS = admin_patterns("/background_updates/start_job$")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastore()
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester.user)
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["job_name"])
+
+ job_name = body["job_name"]
+
+ if job_name == "populate_stats_process_rooms":
+ jobs = [
+ {
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ },
+ ]
+ elif job_name == "regenerate_directory":
+ jobs = [
+ {
+ "update_name": "populate_user_directory_createtables",
+ "progress_json": "{}",
+ "depends_on": "",
+ },
+ {
+ "update_name": "populate_user_directory_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_createtables",
+ },
+ {
+ "update_name": "populate_user_directory_process_users",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_process_rooms",
+ },
+ {
+ "update_name": "populate_user_directory_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_process_users",
+ },
+ ]
+ else:
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name")
+
+ try:
+ await self._store.db_pool.simple_insert_many(
+ table="background_updates",
+ values=jobs,
+ desc=f"admin_api_run_{job_name}",
+ )
+ except self._store.db_pool.engine.module.IntegrityError:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Job %s is already in queue of background updates." % (job_name,),
+ )
+
+ self._store.db_pool.updates.start_doing_background_updates()
+
+ return HTTPStatus.OK, {}
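For illustration, an admin could trigger one of the accepted jobs roughly as follows. This is a
hedged sketch, not part of the patch: the homeserver URL and token are placeholders, and it assumes
the default /_synapse/admin/v1 prefix produced by `admin_patterns`.

import requests

HOMESERVER = "https://synapse.example.com"  # hypothetical
ADMIN_TOKEN = "secret-admin-token"          # hypothetical

resp = requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/background_updates/start_job",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    # "regenerate_directory" is the other accepted job_name.
    json={"job_name": "populate_stats_process_rooms"},
)
resp.raise_for_status()  # 400 if the job_name is unknown or the job is already queued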
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 37cb4d0796..a89dda1ba5 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -448,7 +448,7 @@ class RoomStateRestServlet(RestServlet):
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
- bundle_aggregations=False,
+ bundle_relations=False,
)
ret = {"state": room_state}
@@ -778,7 +778,70 @@ class RoomEventContextServlet(RestServlet):
results["state"],
time_now,
# No need to bundle aggregations for state events
- bundle_aggregations=False,
+ bundle_relations=False,
)
return 200, results
+
+
+class BlockRoomRestServlet(RestServlet):
+ """
+ Manage the blocking of rooms.
+ On PUT: Add or remove a room from the blocking list.
+ On GET: Get the blocking status of a room and the user who blocked it.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastore()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ if not RoomID.is_valid(room_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+ )
+
+ blocked_by = await self._store.room_is_blocked_by(room_id)
+ # Test `is not None`, since `user_id` may be an empty string if someone
+ # manually added an entry to the database.
+ if blocked_by is not None:
+ response = {"block": True, "user_id": blocked_by}
+ else:
+ response = {"block": False}
+
+ return HTTPStatus.OK, response
+
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+
+ if not RoomID.is_valid(room_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+ )
+
+ assert_params_in_dict(content, ["block"])
+ block = content.get("block")
+ if not isinstance(block, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'block' must be a boolean.",
+ Codes.BAD_JSON,
+ )
+
+ if block:
+ await self._store.block_room(room_id, requester.user.to_string())
+ else:
+ await self._store.unblock_room(room_id)
+
+ return HTTPStatus.OK, {"block": block}
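A hedged usage sketch for the new block endpoint (not part of the patch): the homeserver, token and
room ID are placeholders, and the /_synapse/admin/v1 prefix is assumed from `admin_patterns`.

import requests
from urllib.parse import quote

HOMESERVER = "https://synapse.example.com"   # hypothetical
ADMIN_TOKEN = "secret-admin-token"           # hypothetical
ROOM_ID = "!somewhere:example.com"           # hypothetical

headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}
url = f"{HOMESERVER}/_synapse/admin/v1/rooms/{quote(ROOM_ID)}/block"

# Block the room; the server records which admin blocked it.
requests.put(url, headers=headers, json={"block": True}).raise_for_status()

# Query the blocking status; returns e.g. {"block": true, "user_id": "@admin:example.com"}.
print(requests.get(url, headers=headers).json())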
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 23a8bf1fdb..ccd9a2a175 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet):
if auth_user.to_string() == user_id:
raise SynapseError(400, "Cannot use admin API to login as self")
- token = await self.auth_handler.get_access_token_for_user_id(
+ token = await self.auth_handler.create_access_token_for_user_id(
user_id=auth_user.to_string(),
device_id=None,
valid_until_ms=valid_until_ms,
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index a0971ce994..b4cb90cb76 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
def client_patterns(
path_regex: str,
- releases: Iterable[int] = (0,),
+ releases: Iterable[str] = ("r0", "v3"),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:
@@ -52,7 +52,7 @@ def client_patterns(
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
- new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
+ new_prefix = CLIENT_API_PREFIX + f"/{release}"
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
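For reference, a standalone sketch of what the updated helper now produces for a given path regex,
reimplemented outside Synapse purely for illustration (the prefix is hard-coded here):

import re
from typing import Iterable, Pattern

CLIENT_API_PREFIX = "/_matrix/client"


def client_patterns_sketch(
    path_regex: str, releases: Iterable[str] = ("r0", "v3")
) -> Iterable[Pattern]:
    # Mirrors the release-handling part of client_patterns: one compiled pattern
    # per release prefix, e.g. /_matrix/client/r0 and /_matrix/client/v3.
    return [
        re.compile("^" + CLIENT_API_PREFIX + f"/{release}" + path_regex)
        for release in releases
    ]


# /keys/query$ would be served under both the legacy r0 and the new v3 prefixes.
for p in client_patterns_sketch("/keys/query$"):
    print(p.pattern)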
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 7281b2ee29..730c18f08f 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet):
}
"""
- PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
+ PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",))
def __init__(self, hs: "HomeServer"):
super().__init__()
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 467444a041..09f378f919 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -14,7 +14,17 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from typing_extensions import TypedDict
@@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
- parse_boolean,
parse_bytes_from_args,
parse_json_object_from_request,
parse_string,
@@ -72,6 +81,7 @@ class LoginRestServlet(RestServlet):
# JWT configuration variables.
self.jwt_enabled = hs.config.jwt.jwt_enabled
self.jwt_secret = hs.config.jwt.jwt_secret
+ self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
@@ -80,7 +90,9 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2.saml2_enabled
self.cas_enabled = hs.config.cas.cas_enabled
self.oidc_enabled = hs.config.oidc.oidc_enabled
- self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
+ self._msc2918_enabled = (
+ hs.config.registration.refreshable_access_token_lifetime is not None
+ )
self.auth = hs.get_auth()
@@ -152,11 +164,14 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
- # Check if this login should also issue a refresh token, as per
- # MSC2918
- should_issue_refresh_token = parse_boolean(
- request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+ # Check if this login should also issue a refresh token, as per MSC2918
+ should_issue_refresh_token = login_submission.get(
+ "org.matrix.msc2918.refresh_token", False
)
+ if not isinstance(should_issue_refresh_token, bool):
+ raise SynapseError(
+ 400, "`org.matrix.msc2918.refresh_token` should be true or false."
+ )
else:
should_issue_refresh_token = False
@@ -413,7 +428,7 @@ class LoginRestServlet(RestServlet):
errcode=Codes.FORBIDDEN,
)
- user = payload.get("sub", None)
+ user = payload.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
@@ -452,7 +467,10 @@ class RefreshTokenServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
self._clock = hs.get_clock()
- self.access_token_lifetime = hs.config.registration.access_token_lifetime
+ self.refreshable_access_token_lifetime = (
+ hs.config.registration.refreshable_access_token_lifetime
+ )
+ self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
@@ -462,20 +480,33 @@ class RefreshTokenServlet(RestServlet):
if not isinstance(token, str):
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
- valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
- access_token, refresh_token = await self._auth_handler.refresh_token(
- token, valid_until_ms
- )
- expires_in_ms = valid_until_ms - self._clock.time_msec()
- return (
- 200,
- {
- "access_token": access_token,
- "refresh_token": refresh_token,
- "expires_in_ms": expires_in_ms,
- },
+ now = self._clock.time_msec()
+ access_valid_until_ms = None
+ if self.refreshable_access_token_lifetime is not None:
+ access_valid_until_ms = now + self.refreshable_access_token_lifetime
+ refresh_valid_until_ms = None
+ if self.refresh_token_lifetime is not None:
+ refresh_valid_until_ms = now + self.refresh_token_lifetime
+
+ (
+ access_token,
+ refresh_token,
+ actual_access_token_expiry,
+ ) = await self._auth_handler.refresh_token(
+ token, access_valid_until_ms, refresh_valid_until_ms
)
+ response: Dict[str, Union[str, int]] = {
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ }
+
+ # expires_in_ms is only present if the token expires
+ if actual_access_token_expiry is not None:
+ response["expires_in_ms"] = actual_access_token_expiry - now
+
+ return 200, response
+
class SsoRedirectServlet(RestServlet):
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
@@ -561,7 +592,7 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
- if hs.config.registration.access_token_lifetime is not None:
+ if hs.config.registration.refreshable_access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas.cas_enabled:
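To illustrate the behavioural change (a hedged sketch, not part of the patch): the MSC2918
refresh-token opt-in now travels in the JSON body of the login request rather than as a query
parameter. The homeserver and credentials below are placeholders.

import requests

HOMESERVER = "https://synapse.example.com"  # hypothetical

login_body = {
    "type": "m.login.password",
    "identifier": {"type": "m.id.user", "user": "alice"},
    "password": "correct horse battery staple",
    # Previously a ?org.matrix.msc2918.refresh_token=true query parameter;
    # now a boolean field in the request body.
    "org.matrix.msc2918.refresh_token": True,
}

resp = requests.post(f"{HOMESERVER}/_matrix/client/v3/login", json=login_body)
resp.raise_for_status()
# When refresh tokens are enabled, the response may include refresh_token and
# expires_in_ms alongside access_token.
print(resp.json())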
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index bf3cb34146..11fd6cd24d 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
- parse_boolean,
parse_json_object_from_request,
parse_string,
)
@@ -420,7 +419,9 @@ class RegisterRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.registration.enable_registration
- self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
+ self._msc2918_enabled = (
+ hs.config.registration.refreshable_access_token_lifetime is not None
+ )
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -447,9 +448,13 @@ class RegisterRestServlet(RestServlet):
if self._msc2918_enabled:
# Check if this registration should also issue a refresh token, as
# per MSC2918
- should_issue_refresh_token = parse_boolean(
- request, name="org.matrix.msc2918.refresh_token", default=False
+ should_issue_refresh_token = body.get(
+ "org.matrix.msc2918.refresh_token", False
)
+ if not isinstance(should_issue_refresh_token, bool):
+ raise SynapseError(
+ 400, "`org.matrix.msc2918.refresh_token` should be true or false."
+ )
else:
should_issue_refresh_token = False
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 184cfbe196..45e9f1dd90 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -224,17 +224,17 @@ class RelationPaginationServlet(RestServlet):
)
now = self.clock.time_msec()
- # We set bundle_aggregations to False when retrieving the original
+ # We set bundle_relations to False when retrieving the original
# event because we want the content before relations were applied to
# it.
original_event = await self._event_serializer.serialize_event(
- event, now, bundle_aggregations=False
+ event, now, bundle_relations=False
)
# Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them
# here.
serialized_events = await self._event_serializer.serialize_events(
- events, now, bundle_aggregations=False
+ events, now, bundle_relations=False
)
return_value = pagination_chunk.to_dict()
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 03a353d53c..73d0f7c950 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -719,7 +719,7 @@ class RoomEventContextServlet(RestServlet):
results["state"],
time_now,
# No need to bundle aggregations for state events
- bundle_aggregations=False,
+ bundle_relations=False,
)
return 200, results
@@ -1138,12 +1138,12 @@ class RoomSpaceSummaryRestServlet(RestServlet):
class RoomHierarchyRestServlet(RestServlet):
- PATTERNS = (
+ PATTERNS = [
re.compile(
- "^/_matrix/client/unstable/org.matrix.msc2946"
+ "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
"/rooms/(?P<room_id>[^/]*)/hierarchy$"
),
- )
+ ]
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -1168,7 +1168,7 @@ class RoomHierarchyRestServlet(RestServlet):
)
return 200, await self._room_summary_handler.get_room_hierarchy(
- requester.user.to_string(),
+ requester,
room_id,
suggested_only=parse_boolean(request, "suggested_only", default=False),
max_depth=max_depth,
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8c0fdb1940..b6a2485732 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -522,7 +522,7 @@ class SyncRestServlet(RestServlet):
time_now=time_now,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
- bundle_aggregations=False,
+ bundle_relations=False,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 014fa893d6..9b40fd8a6c 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.util.stringutils import is_ascii
+from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
logger = logging.getLogger(__name__)
@@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
+ """Parses the server name, media ID and optional file name from the request URI
+
+ Also performs some rough validation on the server name.
+
+ Args:
+ request: The `Request`.
+
+ Returns:
+ A tuple containing the parsed server name, media ID and optional file name.
+
+ Raises:
+ SynapseError(404): if parsing or validation fails for any reason
+ """
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath: List[bytes] = request.postpath # type: ignore
@@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8")
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+
file_name = None
if len(postpath) > 2:
try:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index bec77088ee..c0e15c6513 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,8 @@
import functools
import os
import re
-from typing import Any, Callable, List, TypeVar, cast
+import string
+from typing import Any, Callable, List, TypeVar, Union, cast
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@@ -37,6 +38,85 @@ def _wrap_in_base_path(func: F) -> F:
return cast(F, _wrapped)
+GetPathMethod = TypeVar(
+ "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
+)
+
+
+def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod:
+ """Wraps a path-returning method to check that the returned path(s) do not escape
+ the media store directory.
+
+ The check is not expected to ever fail, unless `func` is missing a call to
+ `_validate_path_component`, or `_validate_path_component` is buggy.
+
+ Args:
+ func: The `MediaFilePaths` method to wrap. The method may return either a single
+ path, or a list of paths. Returned paths may be either absolute or relative.
+
+ Returns:
+ The method, wrapped with a check to ensure that the returned path(s) lie within
+ the media store directory. Raises a `ValueError` if the check fails.
+ """
+
+ @functools.wraps(func)
+ def _wrapped(
+ self: "MediaFilePaths", *args: Any, **kwargs: Any
+ ) -> Union[str, List[str]]:
+ path_or_paths = func(self, *args, **kwargs)
+
+ if isinstance(path_or_paths, list):
+ paths_to_check = path_or_paths
+ else:
+ paths_to_check = [path_or_paths]
+
+ for path in paths_to_check:
+ # path may be an absolute or relative path, depending on the method being
+ # wrapped. When "appending" an absolute path, `os.path.join` discards the
+ # previous path, which is desired here.
+ normalized_path = os.path.normpath(os.path.join(self.real_base_path, path))
+ if (
+ os.path.commonpath([normalized_path, self.real_base_path])
+ != self.real_base_path
+ ):
+ raise ValueError(f"Invalid media store path: {path!r}")
+
+ return path_or_paths
+
+ return cast(GetPathMethod, _wrapped)
+
+
+ALLOWED_CHARACTERS = set(
+ string.ascii_letters
+ + string.digits
+ + "_-"
+ + ".[]:" # Domain names, IPv6 addresses and ports in server names
+)
+FORBIDDEN_NAMES = {
+ "",
+ os.path.curdir, # "." for the current platform
+ os.path.pardir, # ".." for the current platform
+}
+
+
+def _validate_path_component(name: str) -> str:
+ """Checks that the given string can be safely used as a path component
+
+ Args:
+ name: The path component to check.
+
+ Returns:
+ The path component if valid.
+
+ Raises:
+ ValueError: If `name` cannot be safely used as a path component.
+ """
+ if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
+ raise ValueError(f"Invalid path component: {name!r}")
+
+ return name
+
+
class MediaFilePaths:
"""Describes where files are stored on disk.
@@ -48,22 +128,46 @@ class MediaFilePaths:
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
+ # The media store directory, with all symlinks resolved.
+ self.real_base_path = os.path.realpath(primary_base_path)
+
+ # Refuse to initialize if paths cannot be validated correctly for the current
+ # platform.
+ assert os.path.sep not in ALLOWED_CHARACTERS
+ assert os.path.altsep not in ALLOWED_CHARACTERS
+ # On Windows, paths have all sorts of weirdness which `_validate_path_component`
+ # does not consider. In any case, the remote media store can't work correctly
+ # for certain homeservers there, since ":"s aren't allowed in paths.
+ assert os.name == "posix"
+
+ @_wrap_with_jail_check
def local_media_filepath_rel(self, media_id: str) -> str:
- return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
+ return os.path.join(
+ "local_content",
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
+ )
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
+ @_wrap_with_jail_check
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
- "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
+ "local_thumbnails",
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
+ _validate_path_component(file_name),
)
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+ @_wrap_with_jail_check
def local_media_thumbnail_dir(self, media_id: str) -> str:
"""
Retrieve the local store path of thumbnails of a given media_id
@@ -76,18 +180,24 @@ class MediaFilePaths:
return os.path.join(
self.base_path,
"local_thumbnails",
- media_id[0:2],
- media_id[2:4],
- media_id[4:],
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
)
+ @_wrap_with_jail_check
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
- "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
+ "remote_content",
+ _validate_path_component(server_name),
+ _validate_path_component(file_id[0:2]),
+ _validate_path_component(file_id[2:4]),
+ _validate_path_component(file_id[4:]),
)
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
+ @_wrap_with_jail_check
def remote_media_thumbnail_rel(
self,
server_name: str,
@@ -101,11 +211,11 @@ class MediaFilePaths:
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
- server_name,
- file_id[0:2],
- file_id[2:4],
- file_id[4:],
- file_name,
+ _validate_path_component(server_name),
+ _validate_path_component(file_id[0:2]),
+ _validate_path_component(file_id[2:4]),
+ _validate_path_component(file_id[4:]),
+ _validate_path_component(file_name),
)
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
@@ -113,6 +223,7 @@ class MediaFilePaths:
# Legacy path that was used to store thumbnails previously.
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
+ @_wrap_with_jail_check
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
) -> str:
@@ -120,43 +231,66 @@ class MediaFilePaths:
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
"remote_thumbnail",
- server_name,
- file_id[0:2],
- file_id[2:4],
- file_id[4:],
- file_name,
+ _validate_path_component(server_name),
+ _validate_path_component(file_id[0:2]),
+ _validate_path_component(file_id[2:4]),
+ _validate_path_component(file_id[4:]),
+ _validate_path_component(file_name),
)
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
- server_name,
- file_id[0:2],
- file_id[2:4],
- file_id[4:],
+ _validate_path_component(server_name),
+ _validate_path_component(file_id[0:2]),
+ _validate_path_component(file_id[2:4]),
+ _validate_path_component(file_id[4:]),
)
+ @_wrap_with_jail_check
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
- return os.path.join("url_cache", media_id[:10], media_id[11:])
+ return os.path.join(
+ "url_cache",
+ _validate_path_component(media_id[:10]),
+ _validate_path_component(media_id[11:]),
+ )
else:
- return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
+ return os.path.join(
+ "url_cache",
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
+ )
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
+ @_wrap_with_jail_check
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
- return [os.path.join(self.base_path, "url_cache", media_id[:10])]
+ return [
+ os.path.join(
+ self.base_path, "url_cache", _validate_path_component(media_id[:10])
+ )
+ ]
else:
return [
- os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
- os.path.join(self.base_path, "url_cache", media_id[0:2]),
+ os.path.join(
+ self.base_path,
+ "url_cache",
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ ),
+ os.path.join(
+ self.base_path, "url_cache", _validate_path_component(media_id[0:2])
+ ),
]
+ @_wrap_with_jail_check
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
@@ -168,37 +302,46 @@ class MediaFilePaths:
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
- "url_cache_thumbnails", media_id[:10], media_id[11:], file_name
+ "url_cache_thumbnails",
+ _validate_path_component(media_id[:10]),
+ _validate_path_component(media_id[11:]),
+ _validate_path_component(file_name),
)
else:
return os.path.join(
"url_cache_thumbnails",
- media_id[0:2],
- media_id[2:4],
- media_id[4:],
- file_name,
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
+ _validate_path_component(file_name),
)
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
+ @_wrap_with_jail_check
def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
- return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:])
+ return os.path.join(
+ "url_cache_thumbnails",
+ _validate_path_component(media_id[:10]),
+ _validate_path_component(media_id[11:]),
+ )
else:
return os.path.join(
"url_cache_thumbnails",
- media_id[0:2],
- media_id[2:4],
- media_id[4:],
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
)
url_cache_thumbnail_directory = _wrap_in_base_path(
url_cache_thumbnail_directory_rel
)
+ @_wrap_with_jail_check
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
@@ -206,21 +349,35 @@ class MediaFilePaths:
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
- self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
+ self.base_path,
+ "url_cache_thumbnails",
+ _validate_path_component(media_id[:10]),
+ _validate_path_component(media_id[11:]),
+ ),
+ os.path.join(
+ self.base_path,
+ "url_cache_thumbnails",
+ _validate_path_component(media_id[:10]),
),
- os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
]
else:
return [
os.path.join(
self.base_path,
"url_cache_thumbnails",
- media_id[0:2],
- media_id[2:4],
- media_id[4:],
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ _validate_path_component(media_id[4:]),
),
os.path.join(
- self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
+ self.base_path,
+ "url_cache_thumbnails",
+ _validate_path_component(media_id[0:2]),
+ _validate_path_component(media_id[2:4]),
+ ),
+ os.path.join(
+ self.base_path,
+ "url_cache_thumbnails",
+ _validate_path_component(media_id[0:2]),
),
- os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
]
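A minimal standalone sketch of the validation idea used above, reimplementing the two helpers
outside of MediaFilePaths under the same POSIX assumption (for illustration only):

import os
import string

ALLOWED_CHARACTERS = set(string.ascii_letters + string.digits + "_-" + ".[]:")
FORBIDDEN_NAMES = {"", os.path.curdir, os.path.pardir}


def validate_path_component(name: str) -> str:
    # Reject separators, "..", and anything else that could escape the store.
    if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
        raise ValueError(f"Invalid path component: {name!r}")
    return name


def jail_check(base_path: str, rel_path: str) -> str:
    # Same idea as _wrap_with_jail_check: normalise the joined path and ensure
    # it still lives under the (symlink-resolved) media store directory.
    real_base = os.path.realpath(base_path)
    normalized = os.path.normpath(os.path.join(real_base, rel_path))
    if os.path.commonpath([normalized, real_base]) != real_base:
        raise ValueError(f"Invalid media store path: {rel_path!r}")
    return normalized


print(validate_path_component("ab"))                                   # ok
print(jail_check("/var/lib/synapse/media", "local_content/ab/cd/ef"))  # ok
# validate_path_component("../../etc") would raise ValueError ("/" is not allowed)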
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1605411b00..446204dbe5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6edadea550..499a328201 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,6 +17,7 @@ import logging
from typing import (
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+ state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self,
stream_name: str,
instance_name: str,
- token: StreamToken,
+ token: int,
rows: Iterable[Any],
) -> None:
pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b9a8ca997e..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+ TYPE_CHECKING,
+ AsyncContextManager,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+)
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
from . import engines
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+ """A handler for a given background update.
+
+ Attributes:
+ callback: The function to call to make progress on the background
+ update.
+ oneshot: Whether the update is likely to happen all in one go, ignoring
+ the supplied target duration, e.g. index creation. This is used by
+ the update controller to help correctly schedule the update.
+ """
+
+ callback: Callable[[JsonDict, int], Awaitable[int]]
+ oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+ BACKGROUND_UPDATE_INTERVAL_MS = 1000
+ BACKGROUND_UPDATE_DURATION_MS = 100
+
+ def __init__(self, sleep: bool, clock: Clock):
+ self._sleep = sleep
+ self._clock = clock
+
+ async def __aenter__(self) -> int:
+ if self._sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+ return self.BACKGROUND_UPDATE_DURATION_MS
+
+ async def __aexit__(self, *exc) -> None:
+ pass
+
+
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
@@ -82,22 +131,24 @@ class BackgroundUpdater:
process and autotuning the batch size.
"""
- MINIMUM_BACKGROUND_BATCH_SIZE = 100
+ MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
- BACKGROUND_UPDATE_INTERVAL_MS = 1000
- BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
+ self._database_name = database.name()
+
# if a background update is currently running, its name.
self._current_background_update: Optional[str] = None
+ self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+ self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+ self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
- self._background_update_handlers: Dict[
- str, Callable[[JsonDict, int], Awaitable[int]]
- ] = {}
+ self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
self._all_done = False
# Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
+ def register_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Register callbacks from a module for each hook."""
+ if self._on_update_callback is not None:
+ logger.warning(
+ "More than one module tried to register callbacks for controlling"
+ " background updates. Only the callbacks registered by the first module"
+ " (in order of appearance in Synapse's configuration file) that tried to"
+ " do so will be called."
+ )
+
+ return
+
+ self._on_update_callback = on_update
+
+ if default_batch_size is not None:
+ self._default_batch_size_callback = default_batch_size
+
+ if min_batch_size is not None:
+ self._min_batch_size_callback = min_batch_size
+
+ def _get_context_manager_for_update(
+ self,
+ sleep: bool,
+ update_name: str,
+ database_name: str,
+ oneshot: bool,
+ ) -> AsyncContextManager[int]:
+ """Get a context manager to run a background update with.
+
+ If a module has registered an `on_update` callback, use the context manager
+ it returns.
+
+ Otherwise, returns a context manager that will return a default value, optionally
+ sleeping if needed.
+
+ Args:
+ sleep: Whether we can sleep between updates.
+ update_name: The name of the update.
+ database_name: The name of the database the update is being run on.
+ oneshot: Whether the update will complete all in one go, e.g. index creation.
+ In such cases the returned target duration is ignored.
+
+ Returns:
+ An async context manager which, on entry, returns the target duration in
+ milliseconds that the background update should run for.
+
+ Note: this is a *target*, and an iteration may take substantially longer or
+ shorter.
+ """
+ if self._on_update_callback is not None:
+ return self._on_update_callback(update_name, database_name, oneshot)
+
+ return _BackgroundUpdateContextManager(sleep, self._clock)
+
+ async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+ """The batch size to use for the first iteration of a new background
+ update.
+ """
+ if self._default_batch_size_callback is not None:
+ return await self._default_batch_size_callback(update_name, database_name)
+
+ return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+ async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+ """A lower bound on the batch size of a new background update.
+
+ Used to ensure that progress is always made. Must be greater than 0.
+ """
+ if self._min_batch_size_callback is not None:
+ return await self._min_batch_size_callback(update_name, database_name)
+
+ return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@@ -122,6 +250,8 @@ class BackgroundUpdater:
def start_doing_background_updates(self) -> None:
if self.enabled:
+ # if we start a new background update, not all updates are done.
+ self._all_done = False
run_as_background_process("background_updates", self.run_background_updates)
async def run_background_updates(self, sleep: bool = True) -> None:
@@ -133,13 +263,8 @@ class BackgroundUpdater:
try:
logger.info("Starting background schema updates")
while self.enabled:
- if sleep:
- await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
try:
- result = await self.do_next_background_update(
- self.BACKGROUND_UPDATE_DURATION_MS
- )
+ result = await self.do_next_background_update(sleep)
except Exception:
logger.exception("Error doing update")
else:
@@ -201,13 +326,15 @@ class BackgroundUpdater:
return not update_exists
- async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+ async def do_next_background_update(self, sleep: bool = True) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
Args:
- desired_duration_ms: How long we want to spend updating.
+ sleep: Whether to sleep between updates to limit how quickly we
+ run them.
+
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -250,7 +377,19 @@ class BackgroundUpdater:
self._current_background_update = upd["update_name"]
- await self._do_background_update(desired_duration_ms)
+ # We have a background update to run, otherwise we would have returned
+ # early.
+ assert self._current_background_update is not None
+ update_info = self._background_update_handlers[self._current_background_update]
+
+ async with self._get_context_manager_for_update(
+ sleep=sleep,
+ update_name=self._current_background_update,
+ database_name=self._database_name,
+ oneshot=update_info.oneshot,
+ ) as desired_duration_ms:
+ await self._do_background_update(desired_duration_ms)
+
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -258,7 +397,7 @@ class BackgroundUpdater:
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
- update_handler = self._background_update_handlers[update_name]
+ update_handler = self._background_update_handlers[update_name].callback
performance = self._background_update_performance.get(update_name)
@@ -271,9 +410,14 @@ class BackgroundUpdater:
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
- batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+ batch_size = max(
+ batch_size,
+ await self._min_batch_size(update_name, self._database_name),
+ )
else:
- batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+ batch_size = await self._default_batch_size(
+ update_name, self._database_name
+ )
progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
@@ -292,6 +436,8 @@ class BackgroundUpdater:
duration_ms = time_stop - time_start
+ performance.update(items_updated, duration_ms)
+
logger.info(
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -304,8 +450,6 @@ class BackgroundUpdater:
batch_size,
)
- performance.update(items_updated, duration_ms)
-
return len(self._background_update_performance)
def register_background_update_handler(
@@ -329,7 +473,9 @@ class BackgroundUpdater:
update_name: The name of the update that this code handles.
update_handler: The function that does the update.
"""
- self._background_update_handlers[update_name] = update_handler
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ update_handler
+ )
def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
@@ -451,7 +597,9 @@ class BackgroundUpdater:
await self._end_background_update(update_name)
return 1
- self.register_background_update_handler(update_name, updater)
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d4cab69ebf..0693d39006 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
+ def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -247,7 +247,7 @@ class LoggingTransaction:
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(
- self, callback: Callable[..., None], *args: Any, **kwargs: Any
+ self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 7c0f953365..ab8766c75b 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -599,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
+ REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -614,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
- self.db_pool.updates.register_background_update_handler(
- self.REMOVE_DELETED_DEVICES,
- self._remove_deleted_devices_from_device_inbox,
+ # Used to be a background update that deletes all device_inboxes for deleted
+ # devices.
+ self.db_pool.updates.register_noop_background_update(
+ self.REMOVE_DELETED_DEVICES
)
+ # Used to be a background update that deletes all device_inboxes for hidden
+ # devices.
+ self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
self.db_pool.updates.register_background_update_handler(
- self.REMOVE_HIDDEN_DEVICES,
- self._remove_hidden_devices_from_device_inbox,
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+ self._remove_dead_devices_from_device_inbox,
)
async def _background_drop_index_device_inbox(self, progress, batch_size):
@@ -636,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
- async def _remove_deleted_devices_from_device_inbox(
- self, progress: JsonDict, batch_size: int
+ async def _remove_dead_devices_from_device_inbox(
+ self,
+ progress: JsonDict,
+ batch_size: int,
) -> int:
- """A background update that deletes all device_inboxes for deleted devices.
-
- This should only need to be run once (when users upgrade to v1.47.0)
+ """A background update to remove devices that were either deleted or hidden from
+ the device_inbox table.
Args:
- progress: JsonDict used to store progress of this background update
- batch_size: the maximum number of rows to retrieve in a single select query
+ progress: The update's progress dict.
+ batch_size: The batch size for this update.
Returns:
- The number of deleted rows
+ The number of rows deleted.
"""
- def _remove_deleted_devices_from_device_inbox_txn(
+ def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
- ) -> int:
- """stream_id is not unique
- we need to use an inclusive `stream_id >= ?` clause,
- since we might not have deleted all dead device messages for the stream_id
- returned from the previous query
+ ) -> Tuple[int, bool]:
- Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
- to avoid problems of deleting a large number of rows all at once
- due to a single device having lots of device messages.
- """
+ if "max_stream_id" in progress:
+ max_stream_id = progress["max_stream_id"]
+ else:
+ txn.execute("SELECT max(stream_id) FROM device_inbox")
+ # There's a type mismatch here between how we want to type the row and
+ # what fetchone says it returns, but we silence it because we know that
+ # res can't be None.
+ res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
+ if res[0] is None:
+ # this can only happen if the `device_inbox` table is empty, in which
+ # case we have no work to do.
+ return 0, True
+ else:
+ max_stream_id = res[0]
- last_stream_id = progress.get("stream_id", 0)
+ start = progress.get("stream_id", 0)
+ stop = start + batch_size
+ # delete rows in `device_inbox` which do *not* correspond to a known,
+ # unhidden device.
sql = """
- SELECT device_id, user_id, stream_id
- FROM device_inbox
+ DELETE FROM device_inbox
WHERE
- stream_id >= ?
- AND (device_id, user_id) NOT IN (
- SELECT device_id, user_id FROM devices
+ stream_id >= ? AND stream_id < ?
+ AND NOT EXISTS (
+ SELECT * FROM devices d
+ WHERE
+ d.device_id=device_inbox.device_id
+ AND d.user_id=device_inbox.user_id
+ AND NOT hidden
)
- ORDER BY stream_id
- LIMIT ?
- """
-
- txn.execute(sql, (last_stream_id, batch_size))
- rows = txn.fetchall()
+ """
- num_deleted = 0
- for row in rows:
- num_deleted += self.db_pool.simple_delete_txn(
- txn,
- "device_inbox",
- {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
- )
+ txn.execute(sql, (start, stop))
- if rows:
- # send more than stream_id to progress
- # otherwise it can happen in large deployments that
- # no change of status is visible in the log file
- # it may be that the stream_id does not change in several runs
- self.db_pool.updates._background_update_progress_txn(
- txn,
- self.REMOVE_DELETED_DEVICES,
- {
- "device_id": rows[-1][0],
- "user_id": rows[-1][1],
- "stream_id": rows[-1][2],
- },
- )
-
- return num_deleted
-
- number_deleted = await self.db_pool.runInteraction(
- "_remove_deleted_devices_from_device_inbox",
- _remove_deleted_devices_from_device_inbox_txn,
- )
-
- # The task is finished when no more lines are deleted.
- if not number_deleted:
- await self.db_pool.updates._end_background_update(
- self.REMOVE_DELETED_DEVICES
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+ {
+ "stream_id": stop,
+ "max_stream_id": max_stream_id,
+ },
)
- return number_deleted
-
- async def _remove_hidden_devices_from_device_inbox(
- self, progress: JsonDict, batch_size: int
- ) -> int:
- """A background update that deletes all device_inboxes for hidden devices.
-
- This should only need to be run once (when users upgrade to v1.47.0)
-
- Args:
- progress: JsonDict used to store progress of this background update
- batch_size: the maximum number of rows to retrieve in a single select query
-
- Returns:
- The number of deleted rows
- """
-
- def _remove_hidden_devices_from_device_inbox_txn(
- txn: LoggingTransaction,
- ) -> int:
- """stream_id is not unique
- we need to use an inclusive `stream_id >= ?` clause,
- since we might not have deleted all hidden device messages for the stream_id
- returned from the previous query
-
- Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
- to avoid problems of deleting a large number of rows all at once
- due to a single device having lots of device messages.
- """
-
- last_stream_id = progress.get("stream_id", 0)
-
- sql = """
- SELECT device_id, user_id, stream_id
- FROM device_inbox
- WHERE
- stream_id >= ?
- AND (device_id, user_id) IN (
- SELECT device_id, user_id FROM devices WHERE hidden = ?
- )
- ORDER BY stream_id
- LIMIT ?
- """
-
- txn.execute(sql, (last_stream_id, True, batch_size))
- rows = txn.fetchall()
-
- num_deleted = 0
- for row in rows:
- num_deleted += self.db_pool.simple_delete_txn(
- txn,
- "device_inbox",
- {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
- )
-
- if rows:
- # We don't just save the `stream_id` in progress as
- # otherwise it can happen in large deployments that
- # no change of status is visible in the log file, as
- # it may be that the stream_id does not change in several runs
- self.db_pool.updates._background_update_progress_txn(
- txn,
- self.REMOVE_HIDDEN_DEVICES,
- {
- "device_id": rows[-1][0],
- "user_id": rows[-1][1],
- "stream_id": rows[-1][2],
- },
- )
-
- return num_deleted
+ return stop > max_stream_id
- number_deleted = await self.db_pool.runInteraction(
- "_remove_hidden_devices_from_device_inbox",
- _remove_hidden_devices_from_device_inbox_txn,
+ finished = await self.db_pool.runInteraction(
+ "_remove_devices_from_device_inbox_txn",
+ _remove_dead_devices_from_device_inbox_txn,
)
- # The task is finished when no more lines are deleted.
- if not number_deleted:
+ if finished:
await self.db_pool.updates._end_background_update(
- self.REMOVE_HIDDEN_DEVICES
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
)
- return number_deleted
+ return batch_size
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
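
The replacement update above deletes device_inbox rows in fixed stream-ID ranges and reports itself finished once the range passes the table's maximum stream ID, rather than re-selecting row batches. Below is a self-contained sketch of that shape, using sqlite3, a toy schema and a NOT EXISTS device check that are illustrative assumptions, not Synapse's actual query or background-update machinery.

import sqlite3

def remove_dead_devices_batch(
    conn: sqlite3.Connection, last_stream_id: int, batch_size: int, max_stream_id: int
) -> bool:
    """Delete one stream-ID range of orphaned messages; True means finished."""
    stop = last_stream_id + batch_size
    conn.execute(
        """
        DELETE FROM device_inbox
        WHERE stream_id > ? AND stream_id <= ?
          AND NOT EXISTS (
              SELECT 1 FROM devices AS d
              WHERE d.user_id = device_inbox.user_id
                AND d.device_id = device_inbox.device_id
          )
        """,
        (last_stream_id, stop),
    )
    conn.commit()
    # Progress would be recorded as {"stream_id": stop, "max_stream_id": max_stream_id},
    # so each run covers the next (last_stream_id, stop] range regardless of row count.
    return stop > max_stream_id

if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE devices (user_id TEXT, device_id TEXT);
        CREATE TABLE device_inbox (user_id TEXT, device_id TEXT, stream_id INTEGER);
        INSERT INTO devices VALUES ('@a:x', 'D1');
        INSERT INTO device_inbox VALUES ('@a:x', 'D1', 1), ('@a:x', 'GONE', 2);
        """
    )
    done = remove_dead_devices_batch(conn, last_stream_id=0, batch_size=100, max_stream_id=2)
    print(done, conn.execute("SELECT COUNT(*) FROM device_inbox").fetchone())
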
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..b06c1dc45b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
+ await self.db_pool.runInteraction(
+ "set_e2e_fallback_keys_txn",
+ self._set_e2e_fallback_keys_txn,
+ user_id,
+ device_id,
+ fallback_keys,
+ )
+
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ def _set_e2e_fallback_keys_txn(
+ self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
- await self.db_pool.simple_upsert(
- "e2e_fallback_keys_json",
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
- values={
- "key_id": key_id,
- "key_json": json_encoder.encode(fallback_key),
- "used": False,
- },
- desc="set_e2e_fallback_key",
+ retcol="key_json",
+ allow_none=True,
)
- await self.invalidate_cache_and_stream(
- "get_e2e_unused_fallback_key_types", (user_id, device_id)
- )
+ new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
+
+ # If the uploaded key is the same as the current fallback key,
+ # don't do anything. This prevents marking the key as unused if it
+ # was already used.
+ if old_key_json != new_key_json:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ )
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
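
The new transaction only upserts a fallback key when the uploaded key differs from what is already stored, so re-uploading the current key cannot reset used back to False. A rough standalone sketch of that check follows, using a plain dict as the store and sorted-key json.dumps in place of canonical JSON encoding; it is an illustration of the pattern, not Synapse's storage API.

import json
from typing import Dict, Tuple

def set_fallback_key(
    store: Dict[Tuple[str, str, str], Dict[str, object]],
    user_id: str,
    device_id: str,
    key_id: str,
    fallback_key: dict,
) -> None:
    algorithm, _, _unique = key_id.partition(":")
    row_key = (user_id, device_id, algorithm)
    new_key_json = json.dumps(fallback_key, sort_keys=True)
    existing = store.get(row_key)
    # Only write if the key actually changed, so an unchanged re-upload does
    # not flip `used` back to False.
    if existing is None or existing["key_json"] != new_key_json:
        store[row_key] = {"key_id": key_id, "key_json": new_key_json, "used": False}

if __name__ == "__main__":
    db: Dict[Tuple[str, str, str], Dict[str, object]] = {}
    set_fallback_key(db, "@a:x", "D1", "signed_curve25519:AAAA", {"key": "base64"})
    db[("@a:x", "D1", "signed_curve25519")]["used"] = True
    # Re-uploading the same key leaves `used` untouched.
    set_fallback_key(db, "@a:x", "D1", "signed_curve25519:AAAA", {"key": "base64"})
    print(db[("@a:x", "D1", "signed_curve25519")]["used"])  # True
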
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 120e4807d1..c3440de2cb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
)
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -108,16 +106,21 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
+ # Since we have been configured to write, we ought to have id generators,
+ # rather than id trackers.
+ assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+ assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1553,11 +1556,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.set(
+ (cache_entry.event.event_id,), cache_entry
+ )
txn.call_after(prefill)
@@ -1696,34 +1701,33 @@ class PersistEventsStore:
},
)
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
+ def _handle_event_relations(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
+ """Handles inserting relation data during persistence of events
Args:
- txn
- event (EventBase)
+ txn: The current database transaction.
+ event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
+ # Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- RelationTypes.THREAD,
- ):
- # Unknown relation type
+ if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
+ if not isinstance(parent_id, str):
return
- aggregation_key = relation.get("key")
+ # Annotations have a key field.
+ aggregation_key = None
+ if rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn(
txn,
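
The reworked _handle_event_relations accepts any string rel_type and records an aggregation key only for annotations. A minimal standalone sketch of just that validation step, operating on a plain content dict; extract_relation is an illustrative helper, not part of Synapse.

from typing import Optional, Tuple

ANNOTATION = "m.annotation"

def extract_relation(content: dict) -> Optional[Tuple[str, str, Optional[str]]]:
    """Return (rel_type, parent_id, aggregation_key), or None if there is no valid relation."""
    relation = content.get("m.relates_to")
    if not isinstance(relation, dict):
        return None
    # Relations must have a string type and parent event ID.
    rel_type = relation.get("rel_type")
    if not isinstance(rel_type, str):
        return None
    parent_id = relation.get("event_id")
    if not isinstance(parent_id, str):
        return None
    # Only annotations carry an aggregation key.
    aggregation_key = relation.get("key") if rel_type == ANNOTATION else None
    return rel_type, parent_id, aggregation_key

if __name__ == "__main__":
    print(extract_relation({"m.relates_to": {"rel_type": "m.thread", "event_id": "$parent"}}))
    print(extract_relation({"m.relates_to": {"rel_type": ANNOTATION, "event_id": "$p", "key": "👍"}}))
    print(extract_relation({"body": "no relation"}))
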
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index ae3a8a63e4..c88fd35e7f 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ # The event_thread_relation background update was replaced with the
+ # event_arbitrary_relations one, which handles any relation to avoid
+ # needing to potentially crawl the entire events table in the future.
+ self.db_pool.updates.register_noop_background_update("event_thread_relation")
+
self.db_pool.updates.register_background_update_handler(
- "event_thread_relation", self._event_thread_relation
+ "event_arbitrary_relations",
+ self._event_arbitrary_relations,
)
################################################################################
@@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
- async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
- """Background update handler which will store thread relations for existing events."""
+ async def _event_arbitrary_relations(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "")
- def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
+ # Fetch events and then filter based on whether the event has a
+ # relation or not.
txn.execute(
"""
SELECT event_id, json FROM event_json
- LEFT JOIN event_relations USING (event_id)
- WHERE event_id > ? AND event_relations.event_id IS NULL
+ WHERE event_id > ?
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
- missing_thread_relations = []
+ # (event_id, parent_id, rel_type) for each relation
+ relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
continue
- # If there's no relation (or it is not a thread), skip!
+ # If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
- if relates_to.get("rel_type") != RelationTypes.THREAD:
+
+ # If the relation type or parent event ID is not a string, skip it.
+ #
+ # Do not consider relation types that have existed for a long time,
+ # since they will already be listed in the `event_relations` table.
+ rel_type = relates_to.get("rel_type")
+ if not isinstance(rel_type, str) or rel_type in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
continue
- # Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
- missing_thread_relations.append((event_id, parent_id))
+ relations_to_insert.append((event_id, parent_id, rel_type))
+
+ # Insert the missing data. Note that we upsert here in case the event
+ # has already been processed.
+ if relations_to_insert:
+ self.db_pool.simple_upsert_many_txn(
+ txn=txn,
+ table="event_relations",
+ key_names=("event_id",),
+ key_values=[(r[0],) for r in relations_to_insert],
+ value_names=("relates_to_id", "relation_type"),
+ value_values=[r[1:] for r in relations_to_insert],
+ )
- # Insert the missing data.
- self.db_pool.simple_insert_many_txn(
- txn=txn,
- table="event_relations",
- values=[
- {
- "event_id": event_id,
- "relates_to_Id": parent_id,
- "relation_type": RelationTypes.THREAD,
- }
- for event_id, parent_id in missing_thread_relations
- ],
- )
+ # Iterate the parent IDs and invalidate caches.
+ for parent_id in {r[1] for r in relations_to_insert}:
+ cache_tuple = (parent_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.get_relations_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aggregation_groups_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_thread_summary, cache_tuple
+ )
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
- txn, "event_thread_relation", {"last_event_id": latest_event_id}
+ txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
- desc="event_thread_relation", func=_event_thread_relation_txn
+ desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
)
if not num_rows:
- await self.db_pool.updates._end_background_update("event_thread_relation")
+ await self.db_pool.updates._end_background_update(
+ "event_arbitrary_relations"
+ )
return num_rows
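
The new background update selects rows first and filters them in Python, skipping relation types that were already indexed at write time and anything whose rel_type or event_id is not a string. A condensed sketch of that filter over a list of (event_id, json) rows instead of a database cursor; the names are illustrative only.

import json
from typing import List, Tuple

ALREADY_INDEXED = {"m.annotation", "m.reference", "m.replace"}

def relations_to_backfill(rows: List[Tuple[str, str]]) -> List[Tuple[str, str, str]]:
    """Return (event_id, parent_id, rel_type) for relations worth inserting."""
    out: List[Tuple[str, str, str]] = []
    for event_id, event_json_raw in rows:
        try:
            content = json.loads(event_json_raw).get("content", {})
        except ValueError:
            continue
        relates_to = content.get("m.relates_to")
        if not isinstance(relates_to, dict):
            continue
        rel_type = relates_to.get("rel_type")
        # Long-standing relation types are already in event_relations; skip them.
        if not isinstance(rel_type, str) or rel_type in ALREADY_INDEXED:
            continue
        parent_id = relates_to.get("event_id")
        if not isinstance(parent_id, str):
            continue
        out.append((event_id, parent_id, rel_type))
    return out

if __name__ == "__main__":
    rows = [
        ("$1", json.dumps({"content": {"m.relates_to": {"rel_type": "m.thread", "event_id": "$root"}}})),
        ("$2", json.dumps({"content": {"m.relates_to": {"rel_type": "m.replace", "event_id": "$old"}}})),
        ("$3", "not json"),
    ]
    print(relations_to_backfill(rows))  # only the m.thread row survives
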
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `_enqueue_events` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
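
Swapping the concrete defaults for `...` in these overloads follows the usual typing convention: overload signatures are never executed, so `...` marks "has a default" without repeating (or contradicting) the runtime value. A small sketch of the same pattern on a hypothetical lookup function, not an API from this codebase:

from typing import Literal, Optional, overload

@overload
def lookup(key: str, allow_none: Literal[False] = ...) -> int: ...
@overload
def lookup(key: str, allow_none: Literal[True] = ...) -> Optional[int]: ...

def lookup(key: str, allow_none: bool = False) -> Optional[int]:
    # Single runtime implementation; the overloads only refine the return type.
    table = {"a": 1}
    value = table.get(key)
    if value is None and not allow_none:
        raise KeyError(key)
    return value

if __name__ == "__main__":
    n: int = lookup("a")                          # type checkers see int here
    maybe: Optional[int] = lookup("b", allow_none=True)  # and Optional[int] here
    print(n, maybe)
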
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
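
The _maybe_start_fetch_thread/_fetch_thread split above caps the number of fetcher background processes and restarts one if more work arrives between a fetcher deciding to exit and decrementing the counter. A stripped-down, thread-based sketch of that bookkeeping; FetchPool, MAX_FETCHERS and the plain threading machinery are assumptions for illustration, with no database or Deferreds involved.

import threading
from typing import Callable, List

MAX_FETCHERS = 3

class FetchPool:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._queue: List[Callable[[], None]] = []
        self._ongoing = 0

    def submit(self, work: Callable[[], None]) -> None:
        with self._lock:
            self._queue.append(work)
        self._maybe_start_fetcher()

    def _maybe_start_fetcher(self) -> None:
        with self._lock:
            should_start = bool(self._queue) and self._ongoing < MAX_FETCHERS
            if should_start:
                # Decremented in _fetcher's finally block.
                self._ongoing += 1
        if should_start:
            threading.Thread(target=self._fetcher, daemon=True).start()

    def _fetcher(self) -> None:
        try:
            while True:
                with self._lock:
                    batch, self._queue = self._queue, []
                if not batch:
                    return
                for work in batch:
                    work()
        finally:
            with self._lock:
                self._ongoing -= 1
                leftover = bool(self._queue)
            if leftover:
                # Work may have been queued between our last check and the
                # decrement above; restart unconditionally, mirroring the
                # "should_restart" path in the change above.
                self._maybe_start_fetcher()

if __name__ == "__main__":
    pool = FetchPool()
    done = threading.Event()
    pool.submit(done.set)
    done.wait(timeout=5)
    print("processed:", done.is_set())
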
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: Union[
- StreamIdGenerator, SlavedIdTracker
- ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ )
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 5e55440570..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -84,28 +84,37 @@ class TokenLookupResult:
return self.user_id
-@attr.s(frozen=True, slots=True)
+@attr.s(auto_attribs=True, frozen=True, slots=True)
class RefreshTokenLookupResult:
"""Result of looking up a refresh token."""
- user_id = attr.ib(type=str)
+ user_id: str
"""The user this token belongs to."""
- device_id = attr.ib(type=str)
+ device_id: str
"""The device associated with this refresh token."""
- token_id = attr.ib(type=int)
+ token_id: int
"""The ID of this refresh token."""
- next_token_id = attr.ib(type=Optional[int])
+ next_token_id: Optional[int]
"""The ID of the refresh token which replaced this one."""
- has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+ has_next_refresh_token_been_refreshed: bool
"""True if the next refresh token was used for another refresh."""
- has_next_access_token_been_used = attr.ib(type=bool)
+ has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
+ expiry_ts: Optional[int]
+ """The time at which the refresh token expires and can no longer be used.
+ If None, the refresh token doesn't expire."""
+
+ ultimate_session_expiry_ts: Optional[int]
+ """The time at which the session comes to an end and can no longer be
+ refreshed.
+ If None, the session can be refreshed indefinitely."""
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
@@ -1198,8 +1207,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
+ assert self._account_validity_startup_job_max_delta is not None
expiration_ts = random.randrange(
- expiration_ts - self._account_validity_startup_job_max_delta,
+ int(expiration_ts - self._account_validity_startup_job_max_delta),
expiration_ts,
)
@@ -1625,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
- (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
- at.used has_next_access_token_been_used
+ (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+ at.used AS has_next_access_token_been_used,
+ rt.expiry_ts,
+ rt.ultimate_session_expiry_ts
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1647,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
+ expiry_ts=row[6],
+ ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@@ -1728,11 +1742,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
)
self.db_pool.updates.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather
+ "users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- self.db_pool.updates.register_background_update_handler(
- "users_set_deactivated_flag", self._background_update_set_deactivated_flag
+ self.db_pool.updates.register_noop_background_update(
+ "user_threepids_grandfather"
)
self.db_pool.updates.register_background_index_update(
@@ -1805,35 +1819,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return nb_processed
- async def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.registration.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- await self.db_pool.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
- )
-
- await self.db_pool.updates._end_background_update("user_threepids_grandfather")
-
- return 1
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1943,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@@ -1950,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the refresh token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+ If None, the refresh token does not expire until it has been used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and can not be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1965,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
+ "expiry_ts": expiry_ts,
+ "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)
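
With expiry_ts and ultimate_session_expiry_ts now carried on the lookup result, a refresh handler can reject a token on either bound. Below is a hypothetical validity check over a stand-in dataclass; the helper name, the dataclass and the millisecond clock source are assumptions for illustration, not Synapse's refresh logic.

import time
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class RefreshInfo:
    expiry_ts: Optional[int]                   # ms since epoch; None = token never expires
    ultimate_session_expiry_ts: Optional[int]  # ms since epoch; None = session is unlimited

def refresh_is_allowed(info: RefreshInfo, now_ms: Optional[int] = None) -> bool:
    """True if the refresh token may still be used at `now_ms`."""
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    if info.expiry_ts is not None and now_ms >= info.expiry_ts:
        return False  # the refresh token itself has expired
    if (
        info.ultimate_session_expiry_ts is not None
        and now_ms >= info.ultimate_session_expiry_ts
    ):
        return False  # the whole session has reached its hard end
    return True

if __name__ == "__main__":
    info = RefreshInfo(expiry_ts=None, ultimate_session_expiry_ts=2_000_000_000_000)
    print(refresh_is_allowed(info, now_ms=1_700_000_000_000))  # True
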
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
+ async def event_includes_relation(self, event_id: str) -> bool:
+ """Check if the given event relates to another event.
+
+ An event has a relation if it has a valid m.relates_to with a rel_type
+ and event_id in the content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$other_event_id"
+ }
+ }
+ }
+
+ Args:
+ event_id: The event to check.
+
+ Returns:
+ True if the event includes a valid relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"event_id": event_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_includes_relation",
+ )
+ return result is not None
+
+ async def event_is_target_of_relation(self, parent_id: str) -> bool:
+ """Check if the given event is the target of another event's relation.
+
+ An event is the target of an event relation if it has a valid
+ m.relates_to with a rel_type and event_id pointing to parent_id in the
+ content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$parent_id"
+ }
+ }
+ }
+
+ Args:
+ parent_id: The event to check.
+
+ Returns:
+ True if the event is the target of another event's relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"relates_to_id": parent_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_is_target_of_relation",
+ )
+ return result is not None
+
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_event_has_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_if_event_has_relations", _get_if_event_has_relations
+ "get_if_events_have_relations", _get_if_events_have_relations
)
async def has_user_annotated_event(
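
event_includes_relation and event_is_target_of_relation answer opposite questions against the same event_relations table: does this event point at something, and does anything point at this event. A rough usage sketch against a stand-in store follows; FakeStore and relation_flags are illustrative only and not Synapse code.

import asyncio
from typing import Dict, Protocol

class RelationsStore(Protocol):
    async def event_includes_relation(self, event_id: str) -> bool: ...
    async def event_is_target_of_relation(self, parent_id: str) -> bool: ...

async def relation_flags(store: RelationsStore, event_id: str) -> dict:
    """Summarise how an event participates in relations (illustrative only)."""
    return {
        "is_relation": await store.event_includes_relation(event_id),
        "has_relations": await store.event_is_target_of_relation(event_id),
    }

class FakeStore:
    def __init__(self, relations: Dict[str, str]):
        # Maps child event_id -> parent event_id, standing in for event_relations.
        self._relations = relations

    async def event_includes_relation(self, event_id: str) -> bool:
        return event_id in self._relations

    async def event_is_target_of_relation(self, parent_id: str) -> bool:
        return parent_id in self._relations.values()

if __name__ == "__main__":
    store = FakeStore({"$edit": "$original"})
    print(asyncio.run(relation_flags(store, "$original")))  # is_relation False, has_relations True
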
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 17b398bb69..7d694d852d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
+ """
+ Function to retrieve the user who has blocked the room.
+ user_id is non-nullable.
+ It returns None if the room is not blocked.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="room_is_blocked_by",
+ )
+
async def get_rooms_paginate(
self,
start: int,
@@ -1775,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
+
+ async def unblock_room(self, room_id: str) -> None:
+ """Remove the room from the blocking list.
+
+ Args:
+ room_id: Room to unblock
+ """
+ await self.db_pool.simple_delete(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ desc="unblock_room",
+ )
+ await self.db_pool.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
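
room_is_blocked_by and unblock_room round out the existing block support: an admin-side flow can report who blocked a room and then clear the block. A sketch against an in-memory stand-in for the blocked_rooms table; FakeRoomStore and its block_room helper are assumptions for illustration, not Synapse APIs.

import asyncio
from typing import Dict, Optional

class FakeRoomStore:
    def __init__(self) -> None:
        self._blocked: Dict[str, str] = {}  # room_id -> user_id who blocked it

    async def block_room(self, room_id: str, user_id: str) -> None:
        self._blocked[room_id] = user_id

    async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
        # None means the room is not blocked, mirroring the store method above.
        return self._blocked.get(room_id)

    async def unblock_room(self, room_id: str) -> None:
        self._blocked.pop(room_id, None)

async def main() -> None:
    store = FakeRoomStore()
    await store.block_room("!room:example.org", "@admin:example.org")
    print(await store.room_is_blocked_by("!room:example.org"))  # @admin:example.org
    await store.unblock_room("!room:example.org")
    print(await store.room_is_blocked_by("!room:example.org"))  # None

if __name__ == "__main__":
    asyncio.run(main())
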
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 8b9c6adae2..e45adfcb55 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -131,24 +131,16 @@ def prepare_database(
"config==None in prepare_database, but database is not empty"
)
- # if it's a worker app, refuse to upgrade the database, to avoid multiple
- # workers doing it at once.
- if config.worker.worker_app is None:
- _upgrade_existing_database(
- cur,
- version_info,
- database_engine,
- config,
- databases=databases,
- )
- elif version_info.current_version < SCHEMA_VERSION:
- # If the DB is on an older version than we expect then we refuse
- # to start the worker (as the main process needs to run first to
- # update the schema).
- raise UpgradeDatabaseException(
- OUTDATED_SCHEMA_ON_WORKER_ERROR
- % (SCHEMA_VERSION, version_info.current_version)
- )
+ # This should be run on all processes, master or worker. The master will
+ # apply the deltas, while workers will check if any outstanding deltas
+ # exist and raise a PrepareDatabaseException if they do.
+ _upgrade_existing_database(
+ cur,
+ version_info,
+ database_engine,
+ config,
+ databases=databases,
+ )
else:
logger.info("%r: Initialising new database", databases)
@@ -358,6 +350,18 @@ def _upgrade_existing_database(
is_worker = config and config.worker.worker_app is not None
+ # If the schema version needs to be updated, and we are on a worker, we immediately
+ # know to bail out as workers cannot update the database schema. Only one process
+ # may update the database at a time, so we delegate this task to the master.
+ if is_worker and current_schema_state.current_version < SCHEMA_VERSION:
+ # If the DB is on an older version than we expect then we refuse
+ # to start the worker (as the main process needs to run first to
+ # update the schema).
+ raise UpgradeDatabaseException(
+ OUTDATED_SCHEMA_ON_WORKER_ERROR
+ % (SCHEMA_VERSION, current_schema_state.current_version)
+ )
+
if (
current_schema_state.compat_version is not None
and current_schema_state.compat_version > SCHEMA_VERSION
diff --git a/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..82f6408b36
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- when a device was deleted using Synapse earlier than 1.47.0.
+-- This runs as a background task, but may take a bit to finish.
+
+-- Remove any existing instances of this job running. It's OK to stop and restart this job,
+-- as it's just deleting entries from a table - no progress will be lost.
+--
+-- This is necessary because a similar migration that runs this job was
+-- accidentally included in schema version 64 during v1.47.0rc1 and rc2. If a
+-- homeserver had updated from Synapse <=v1.45.0 (schema version <=64),
+-- then it would have already started running this background update.
+-- If that update was still running, then simply inserting it again would
+-- cause an SQL failure. So we effectively do an "upsert" here instead.
+
+DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6506, 'remove_deleted_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
index d60517f7b4..267b2cb539 100644
--- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql
+++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
@@ -15,4 +15,4 @@
-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
- (6502, 'event_thread_relation', '{}');
+ (6507, 'event_arbitrary_relations', '{}');
diff --git a/synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql
index 076179123d..d79455c2ce 100644
--- a/synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql
+++ b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql
@@ -13,10 +13,6 @@
* limitations under the License.
*/
-
--- Remove messages from the device_inbox table which were orphaned
--- when a device was deleted using Synapse earlier than 1.47.0.
--- This runs as background task, but may take a bit to finish.
-
+-- Background update to clear the inboxes of hidden and deleted devices.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
- (6505, 'remove_deleted_devices_from_device_inbox', '{}');
+ (6508, 'remove_dead_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
new file mode 100644
index 0000000000..bdc491c817
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
@@ -0,0 +1,28 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE refresh_tokens
+ -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
+ -- They may not be used after they have expired.
+ -- If null, then the refresh token's lifetime is unlimited.
+ ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
+
+ALTER TABLE refresh_tokens
+ -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
+ -- No matter how much the access and refresh tokens are refreshed, they cannot
+ -- be extended past this time.
+ -- If null, then the session length is unlimited.
+ ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
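Taken together, the two new columns bound a refresh token in two independent ways: expiry_ts limits this particular token, while ultimate_session_expiry_ts caps the whole session no matter how often it is refreshed, with NULL meaning "no limit" in both cases. A rough standalone sketch of that rule (a hypothetical helper, not the actual validation code in Synapse):

    from typing import Optional

    def refresh_token_usable(
        now_ms: int,
        expiry_ts: Optional[int],
        ultimate_session_expiry_ts: Optional[int],
    ) -> bool:
        # NULL (None) in either column means that limit does not apply.
        if expiry_ts is not None and now_ms >= expiry_ts:
            return False  # this refresh token has expired
        if ultimate_session_expiry_ts is not None and now_ms >= ultimate_session_expiry_ts:
            return False  # the session as a whole has hit its hard deadline
        return True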
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def get_next(self) -> AsyncContextManager[int]:
- raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+ """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+ Stream IDs are monotonically increasing or decreasing integers representing write
+ transactions. The "current" stream ID is the stream ID such that all transactions
+ with equal or smaller stream IDs have completed. Since transactions may complete out
+ of order, this is not the same as the stream ID of the last completed transaction.
+
+ Completed transactions include both committed transactions and transactions that
+ have been rolled back.
+ """
@abc.abstractmethod
- def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ def advance(self, instance_name: str, new_id: int) -> None:
+ """Advance the position of the named writer to the given ID, if greater
+ than existing entry.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+
+ Returns:
+ The maximum stream id.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to `get_current_token`.
+ """
+ raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+ """Generates stream IDs for a stream that may have multiple writers.
+
+ Each stream ID represents a write transaction, whose completion is tracked
+ so that the "current" stream ID of the stream can be determined.
+
+ See `AbstractStreamIdTracker` for more details.
+ """
+
+ @abc.abstractmethod
+ def get_next(self) -> AsyncContextManager[int]:
+ """
+ Usage:
+ async with stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ """
+ Usage:
+            async with stream_id_gen.get_next_mult(n) as stream_ids:
+ # ... persist events ...
+ """
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
- """Used to generate new stream ids when persisting events while keeping
- track of which transactions have been completed.
+ """Generates and tracks stream IDs for a stream with a single writer.
- This allows us to get the "current" stream id, i.e. the stream id such that
- all ids less than or equal to it have completed. This handles the fact that
- persistence of events can complete out of order.
+ This class must only be used when the current Synapse process is the sole
+ writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ def advance(self, instance_name: str, new_id: int) -> None:
+ # `StreamIdGenerator` should only be used when there is a single writer,
+ # so replication should never happen.
+ raise Exception("Replication is not supported by StreamIdGenerator")
+
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
with self._lock:
self._current += self._step
next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
- """
- Usage:
- async with stream_id_gen.get_next(n) as stream_ids:
- # ... persist events ...
- """
with self._lock:
next_ids = range(
self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
-
- Returns:
- The maximum stream id.
- """
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
- """An ID generator that tracks a stream that can have multiple writers.
+ """Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
- """
- Usage:
- async with stream_id_gen.get_next_mult(5) as stream_ids:
- # ... persist events ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
-
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer."""
-
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
- """Advance the position of the named writer to the given ID, if greater
- than existing entry.
- """
-
new_id *= self._return_factor
with self._lock:
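The docstrings above lean on one subtle rule: because writes can complete out of order, the "current" position is not the last completed ID but the largest ID below the oldest still-unfinished one. A toy standalone illustration for a forwards stream (not Synapse code):

    from typing import Set

    def current_token(issued_up_to: int, unfinished: Set[int], step: int = 1) -> int:
        """Largest stream ID such that every ID up to and including it has completed."""
        if not unfinished:
            return issued_up_to
        # The position stops just short of the oldest write still in flight.
        return min(unfinished) - step

    # IDs 1..5 handed out, only 3 still in flight: the current token is 2, not 5.
    assert current_token(5, {3}) == 2
    assert current_token(5, set()) == 5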
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 561b962e14..20ce294209 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -27,6 +27,7 @@ from typing import (
Generic,
Hashable,
Iterable,
+ Iterator,
Optional,
Set,
TypeVar,
@@ -40,7 +41,6 @@ from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
-from twisted.python import failure
from twisted.python.failure import Failure
from synapse.logging.context import (
@@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", [])
- def callback(r):
+ def callback(r: _T) -> _T:
object.__setattr__(self, "_result", (True, r))
# once we have set _result, no more entries will be added to _observers,
@@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
)
return r
- def errback(f):
+ def errback(f: Failure) -> Optional[Failure]:
object.__setattr__(self, "_result", (False, f))
# once we have set _result, no more entries will be added to _observers,
@@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
- f.value.__failure__ = f
+ f.value.__failure__ = f # type: ignore[union-attr]
try:
observer.errback(f)
except Exception as e:
@@ -314,7 +314,7 @@ class Linearizer:
# will release the lock.
@contextmanager
- def _ctx_manager(_):
+ def _ctx_manager(_: None) -> Iterator[None]:
try:
yield
finally:
@@ -355,7 +355,7 @@ class Linearizer:
new_defer = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1
- def cb(_r):
+ def cb(_r: None) -> "defer.Deferred[None]":
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
entry.count += 1
@@ -371,7 +371,7 @@ class Linearizer:
# code must be synchronous, so this is the only sensible place.)
return self._clock.sleep(0)
- def eb(e):
+ def eb(e: Failure) -> Failure:
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
@@ -435,7 +435,7 @@ class ReadWriteLock:
await make_deferred_yieldable(curr_writer)
@contextmanager
- def _ctx_manager():
+ def _ctx_manager() -> Iterator[None]:
try:
yield
finally:
@@ -464,7 +464,7 @@ class ReadWriteLock:
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
- def _ctx_manager():
+ def _ctx_manager() -> Iterator[None]:
try:
yield
finally:
@@ -524,7 +524,7 @@ def timeout_deferred(
delayed_call = reactor.callLater(timeout, time_it_out)
- def convert_cancelled(value: failure.Failure):
+ def convert_cancelled(value: Failure) -> Failure:
# if the original deferred was cancelled, and our timeout has fired, then
# the reason it was cancelled was due to our timeout. Turn the CancelledError
# into a TimeoutError.
@@ -534,7 +534,7 @@ def timeout_deferred(
deferred.addErrback(convert_cancelled)
- def cancel_timeout(result):
+ def cancel_timeout(result: _T) -> _T:
# stop the pending call to cancel the deferred if it's been fired
if delayed_call.active():
delayed_call.cancel()
@@ -542,11 +542,11 @@ def timeout_deferred(
deferred.addBoth(cancel_timeout)
- def success_cb(val):
+ def success_cb(val: _T) -> None:
if not new_d.called:
new_d.callback(val)
- def failure_cb(val):
+ def failure_cb(val: Failure) -> None:
if not new_d.called:
new_d.errback(val)
@@ -557,13 +557,13 @@ def timeout_deferred(
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable: # should be: Generic[R]
"""Simple awaitable that returns the provided value."""
- value = attr.ib(type=Any) # should be: R
+ value: Any # should be: R
- def __await__(self):
+ def __await__(self) -> Any:
return self
def __iter__(self) -> "DoneAwaitable":
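DoneAwaitable relies on the low-level awaitable protocol: __await__ returns an iterator, and raising StopIteration(value) from it is how the result is delivered to the awaiting coroutine. A standalone sketch of the same trick with an illustrative class name (not the Synapse class itself):

    import asyncio
    from typing import Any, Generic, Iterator, TypeVar

    R = TypeVar("R")

    class Done(Generic[R]):
        """Awaitable that resolves immediately to a pre-computed value."""

        def __init__(self, value: R) -> None:
            self.value = value

        def __await__(self) -> Iterator[Any]:
            return self

        def __iter__(self) -> "Done[R]":
            return self

        def __next__(self) -> R:
            # Finishing the iteration with StopIteration(value) is what makes
            # `await Done(x)` evaluate to x.
            raise StopIteration(self.value)

    async def demo() -> int:
        return await Done(42)

    assert asyncio.run(demo()) == 42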
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index df4d61e4b6..15debd6c46 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -17,7 +17,7 @@ import logging
import typing
from enum import Enum, auto
from sys import intern
-from typing import Callable, Dict, Optional, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized
import attr
from prometheus_client.core import Gauge
@@ -58,20 +58,20 @@ class EvictionReason(Enum):
time = auto()
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class CacheMetric:
- _cache = attr.ib()
- _cache_type = attr.ib(type=str)
- _cache_name = attr.ib(type=str)
- _collect_callback = attr.ib(type=Optional[Callable])
+ _cache: Sized
+ _cache_type: str
+ _cache_name: str
+ _collect_callback: Optional[Callable]
- hits = attr.ib(default=0)
- misses = attr.ib(default=0)
+ hits: int = 0
+ misses: int = 0
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
factory=collections.Counter
)
- memory_usage = attr.ib(default=None)
+ memory_usage: Optional[int] = None
def inc_hits(self) -> None:
self.hits += 1
@@ -89,13 +89,14 @@ class CacheMetric:
self.memory_usage += memory
def dec_memory_usage(self, memory: int) -> None:
+ assert self.memory_usage is not None
self.memory_usage -= memory
def clear_memory_usage(self) -> None:
if self.memory_usage is not None:
self.memory_usage = 0
- def describe(self):
+ def describe(self) -> List[str]:
return []
def collect(self) -> None:
@@ -118,8 +119,9 @@ class CacheMetric:
self.eviction_size_by_reason[reason]
)
cache_total.labels(self._cache_name).set(self.hits + self.misses)
- if getattr(self._cache, "max_size", None):
- cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+ max_size = getattr(self._cache, "max_size", None)
+ if max_size:
+ cache_max_size.labels(self._cache_name).set(max_size)
if TRACK_MEMORY_USAGE:
# self.memory_usage can be None if nothing has been inserted
@@ -193,7 +195,7 @@ KNOWN_KEYS = {
}
-def intern_string(string):
+def intern_string(string: Optional[str]) -> Optional[str]:
"""Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
@@ -204,7 +206,7 @@ def intern_string(string):
return string
-def intern_dict(dictionary):
+def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""Takes a dictionary and interns well known keys and their values"""
return {
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@@ -212,7 +214,7 @@ def intern_dict(dictionary):
}
-def _intern_known_values(key, value):
+def _intern_known_values(key: str, value: Any) -> Any:
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:
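Several hunks in this file (and the matching ones in async_helpers.py and expiringcache.py) are the same mechanical conversion: attr.ib(type=..., default=...) fields become plain annotated attributes under attr.s(auto_attribs=True). A minimal standalone example of the equivalence, with illustrative class names:

    import attr

    @attr.s(slots=True)
    class OldStyle:
        name = attr.ib(type=str)
        hits = attr.ib(default=0)

    @attr.s(slots=True, auto_attribs=True)
    class NewStyle:
        # The annotation replaces attr.ib(type=...); defaults stay inline.
        name: str
        hits: int = 0

    assert OldStyle("cache").hits == NewStyle("cache").hits == 0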
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index da502aec11..3c4cc093af 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
- def invalidate(self, key) -> None:
+ def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index b9dcca17f1..375cd443f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,12 +19,15 @@ import logging
from typing import (
Any,
Callable,
+ Dict,
Generic,
+ Hashable,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
+ Type,
TypeVar,
Union,
cast,
@@ -32,6 +35,7 @@ from typing import (
from weakref import WeakValueDictionary
from twisted.internet import defer
+from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
- def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+ def __init__(
+ self,
+ orig: Callable[..., Any],
+ num_args: Optional[int],
+ cache_context: bool = False,
+ ):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __init__(
self,
- orig,
+ orig: Callable[..., Any],
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
- def __get__(self, obj, owner):
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
max_size=self.max_entries,
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
- def _wrapped(*args, **kwargs):
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2
Args:
- num_args (int): number of positional arguments (excluding ``self`` and
+ num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
def __init__(
self,
- orig,
- max_entries=1000,
- num_args=None,
- tree=False,
- cache_context=False,
- iterable=False,
+ orig: Callable[..., Any],
+ max_entries: int = 1000,
+ num_args: Optional[int] = None,
+ tree: bool = False,
+ cache_context: bool = False,
+ iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries
- def __get__(self, obj, owner):
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
- def _wrapped(*args, **kwargs):
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
of results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=None):
+ def __init__(
+ self,
+ orig: Callable[..., Any],
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ ):
"""
Args:
- orig (function)
- cached_method_name (str): The name of the cached method.
- list_name (str): Name of the argument which is the bulk lookup list
- num_args (int): number of positional arguments (excluding ``self``,
+ orig
+ cached_method_name: The name of the cached method.
+ list_name: Name of the argument which is the bulk lookup list
+ num_args: number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
% (self.list_name, cached_method_name)
)
- def __get__(self, obj, objtype=None):
+ def __get__(
+ self, obj: Optional[Any], objtype: Optional[Type] = None
+ ) -> Callable[..., Any]:
cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
results = {}
- def update_results_dict(res, arg):
+ def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# otherwise a tuple is used.
if num_args == 1:
- def arg_to_cache_key(arg):
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
else:
keylist = list(keyargs)
- def arg_to_cache_key(arg):
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
keylist[self.list_pos] = arg
return tuple(keylist)
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
- def complete_all(res):
+ def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferreds_map[e].callback(val)
results[e] = val
- def errback(f):
+ def errback(f: Failure) -> Failure:
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
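The newly typed arg_to_cache_key helpers encode how bulk-lookup caches key their entries: a single-argument method uses the bare argument, while a multi-argument method substitutes each bulk-lookup value into a fixed tuple of the remaining arguments. A standalone sketch of that keying scheme (illustrative names only, not the real descriptor machinery):

    from typing import Any, Callable, Hashable, List

    def make_arg_to_cache_key(
        keyargs: List[Any], list_pos: int, num_args: int
    ) -> Callable[[Hashable], Hashable]:
        if num_args == 1:
            # Single-argument caches key on the bare value.
            return lambda arg: arg

        keylist = list(keyargs)

        def arg_to_cache_key(arg: Hashable) -> Hashable:
            # Substitute the bulk-lookup slot; the other arguments stay fixed.
            keylist[list_pos] = arg
            return tuple(keylist)

        return arg_to_cache_key

    key_fn = make_arg_to_cache_key(["!room:example.org", None], list_pos=1, num_args=2)
    assert key_fn("@alice:example.org") == ("!room:example.org", "@alice:example.org")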
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c3f72aa06d..67ee4c693b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
import attr
from typing_extensions import Literal
+from twisted.internet import defer
+
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
@@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
# Don't bother starting the loop if things never expire
return
- def f():
+ def f() -> "defer.Deferred[None]":
return run_as_background_process(
"prune_cache_%s" % self._cache_name, self._prune_cache
)
@@ -157,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
self[key] = value
return value
- def _prune_cache(self) -> None:
+ async def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
@@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
return False
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _CacheEntry:
- time = attr.ib(type=int)
- value = attr.ib()
+ time: int
+ value: Any
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 31097d6439..91837655f8 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -18,12 +18,13 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
-def user_left_room(distributor, user, room_id):
+def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
distributor.fire("user_left_room", user=user, room_id=room_id)
@@ -63,7 +64,7 @@ class Distributor:
self.pre_registration[name] = []
self.pre_registration[name].append(observer)
- def fire(self, name: str, *args, **kwargs) -> None:
+ def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
"""Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred.
@@ -95,7 +96,7 @@ class Signal:
Each observer callable may return a Deferred."""
self.observers.append(observer)
- def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
+ def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.
@@ -103,7 +104,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
- async def do(observer):
+ async def do(observer: Callable[..., Any]) -> Any:
try:
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
@@ -120,5 +121,5 @@ class Signal:
defer.gatherResults(deferreds, consumeErrors=True)
)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index a447ce4e55..214eb17fbc 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -3,23 +3,52 @@
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
# private class.
-
from socket import (
AF_INET,
AF_INET6,
AF_UNSPEC,
SOCK_DGRAM,
SOCK_STREAM,
+ AddressFamily,
+ SocketKind,
gaierror,
getaddrinfo,
)
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ List,
+ NoReturn,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.interfaces import (
+ IAddress,
+ IHostnameResolver,
+ IHostResolution,
+ IReactorThreads,
+ IResolutionReceiver,
+)
from twisted.internet.threads import deferToThreadPool
+if TYPE_CHECKING:
+ # The types below are copied from
+ # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
+ # so that the type hints can match the interfaces.
+ from twisted.python.runtime import platform
+
+ if platform.supportsThreads():
+ from twisted.python.threadpool import ThreadPool
+ else:
+ ThreadPool = object # type: ignore[misc, assignment]
+
@implementer(IHostResolution)
class HostResolution:
@@ -27,13 +56,13 @@ class HostResolution:
The in-progress resolution of a given hostname.
"""
- def __init__(self, name):
+ def __init__(self, name: str):
"""
Create a L{HostResolution} with the given name.
"""
self.name = name
- def cancel(self):
+ def cancel(self) -> NoReturn:
# IHostResolution.cancel
raise NotImplementedError()
@@ -62,6 +91,17 @@ _socktypeToType = {
}
+_GETADDRINFO_RESULT = List[
+ Tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ Union[Tuple[str, int], Tuple[str, int, int, int]],
+ ]
+]
+
+
@implementer(IHostnameResolver)
class GAIResolver:
"""
@@ -69,7 +109,12 @@ class GAIResolver:
L{getaddrinfo} in a thread.
"""
- def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+ def __init__(
+ self,
+ reactor: IReactorThreads,
+ getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
+ getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
+ ):
"""
Create a L{GAIResolver}.
@param reactor: the reactor to schedule result-delivery on
@@ -89,14 +134,16 @@ class GAIResolver:
)
self._getaddrinfo = getaddrinfo
- def resolveHostName(
+    # The types on IHostnameResolver are incorrect in Twisted, see
+ # https://twistedmatrix.com/trac/ticket/10276
+ def resolveHostName( # type: ignore[override]
self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics="TCP",
- ):
+ resolutionReceiver: IResolutionReceiver,
+ hostName: str,
+ portNumber: int = 0,
+ addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+ transportSemantics: str = "TCP",
+ ) -> IHostResolution:
"""
See L{IHostnameResolver.resolveHostName}
@param resolutionReceiver: see interface
@@ -112,7 +159,7 @@ class GAIResolver:
]
socketType = _transportToSocket[transportSemantics]
- def get():
+ def get() -> _GETADDRINFO_RESULT:
try:
return self._getaddrinfo(
hostName, portNumber, addressFamily, socketType
@@ -125,7 +172,7 @@ class GAIResolver:
resolutionReceiver.resolutionBegan(resolution)
@d.addCallback
- def deliverResults(result):
+ def deliverResults(result: _GETADDRINFO_RESULT) -> None:
for family, socktype, _proto, _cannoname, sockaddr in result:
addrType = _afToType[family]
resolutionReceiver.addressResolved(
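_GETADDRINFO_RESULT simply spells out the 5-tuples that the standard library's getaddrinfo returns. A quick standalone check of that shape (the concrete values depend on the local resolver, so this only illustrates the structure):

    from socket import AF_UNSPEC, SOCK_STREAM, getaddrinfo

    for family, socktype, proto, canonname, sockaddr in getaddrinfo(
        "matrix.org", 443, AF_UNSPEC, SOCK_STREAM
    ):
        # sockaddr is (host, port) for IPv4 and (host, port, flowinfo, scope_id) for IPv6.
        print(family, socktype, proto, repr(canonname), sockaddr)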
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1e784b3f1f..98ee49af6e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -56,14 +56,22 @@ block_db_sched_duration = Counter(
"synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
)
+
+# This is dynamically created in InFlightGauge.__init__.
+class _InFlightMetric(Protocol):
+ real_time_max: float
+ real_time_sum: float
+
+
# Tracks the number of blocks currently active
-in_flight = InFlightGauge(
+in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
"synapse_util_metrics_block_in_flight",
"",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
+
T = TypeVar("T", bound=Callable[..., Any])
@@ -180,7 +188,7 @@ class Measure:
"""
return self._logging_context.get_resource_usage()
- def _update_in_flight(self, metrics) -> None:
+ def _update_in_flight(self, metrics: _InFlightMetric) -> None:
"""Gets called when processing in flight metrics"""
assert self.start is not None
duration = self.clock.time() - self.start
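_InFlightMetric is purely a structural type: InFlightGauge creates its per-block metrics object dynamically, so the Protocol only tells the type checker which attributes callbacks such as _update_in_flight may rely on. A small standalone example of the same idea, with illustrative names:

    from typing import Protocol

    class HasRealTime(Protocol):
        real_time_max: float
        real_time_sum: float

    class DynamicMetric:
        """Stands in for the object a gauge might build at runtime."""

        def __init__(self) -> None:
            self.real_time_max = 0.0
            self.real_time_sum = 0.0

    def record(metric: HasRealTime, duration: float) -> None:
        # Structural typing: DynamicMetric never names HasRealTime, but it
        # satisfies the protocol because it has the right attributes.
        metric.real_time_sum += duration
        metric.real_time_max = max(metric.real_time_max, duration)

    record(DynamicMetric(), 0.25)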
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f029432191..ea1032b4fc 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -19,6 +19,8 @@ import string
from collections.abc import Iterable
from typing import Optional, Tuple
+from netaddr import valid_ipv6
+
from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -97,7 +99,10 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
raise ValueError("Invalid server name '%s'" % server_name)
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+# An approximation of the domain name syntax in RFC 1035, section 2.3.1.
+# NB: "\Z" is not equivalent to "$".
+# The latter will match the position before a "\n" at the end of a string.
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")
def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
@@ -122,13 +127,15 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
if host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
- return host, port
- # otherwise it should only be alphanumerics.
- if not VALID_HOST_REGEX.match(host):
- raise ValueError(
- "Server name '%s' contains invalid characters" % (server_name,)
- )
+ # valid_ipv6 raises when given an empty string
+ ipv6_address = host[1:-1]
+ if not ipv6_address or not valid_ipv6(ipv6_address):
+ raise ValueError(
+ "Server name '%s' is not a valid IPv6 address" % (server_name,)
+ )
+ elif not VALID_HOST_REGEX.match(host):
+ raise ValueError("Server name '%s' has an invalid format" % (server_name,))
return host, port
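Two behaviours from the hunk above are worth illustrating: \Z rejects a trailing newline that $ would accept, and bracketed hosts are now validated as IPv6 literals via netaddr.valid_ipv6. A standalone sketch (assumes netaddr is installed, as it already is for Synapse):

    import re
    from netaddr import valid_ipv6

    HOST_Z = re.compile(r"\A[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*\Z")
    HOST_DOLLAR = re.compile(r"\A[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*$")

    assert HOST_Z.match("example.com")
    assert not HOST_Z.match("example.com\n")    # \Z rejects the trailing newline
    assert HOST_DOLLAR.match("example.com\n")   # $ would let it through

    assert valid_ipv6("::1")           # what the bracketed-host branch accepts
    assert not valid_ipv6("not-ipv6")  # unparsable (non-empty) strings are simply invalid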