diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6766ef445c..ee3313a41c 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,6 +17,7 @@
""" This is a reference implementation of a Matrix home server.
"""
+import os
import sys
# Check that we're not running on an unsupported Python version.
@@ -35,4 +36,11 @@ try:
except ImportError:
pass
-__version__ = "1.3.1"
+__version__ = "1.4.1"
+
+if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
+ # We import here so that we don't have to install a bunch of deps when
+ # running the packaging tox test.
+ from synapse.util.patch_inline_callbacks import do_patch
+
+ do_patch()
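One quirk worth noting with this gate: `bool()` over `os.environ.get(...)` is truthy for any non-empty string, so `SYNAPSE_TEST_PATCH_LOG_CONTEXTS=0` would still enable the patch; only an unset (or empty) variable disables it. A minimal sketch of that behaviour, in plain Python rather than Synapse code:

```python
import os

def patch_enabled() -> bool:
    # Any non-empty value enables the patch; bool("0") and
    # bool("false") are both True.
    return bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False))

os.environ["SYNAPSE_TEST_PATCH_LOG_CONTEXTS"] = "false"
assert patch_enabled()  # still truthy

del os.environ["SYNAPSE_TEST_PATCH_LOG_CONTEXTS"]
assert not patch_enabled()
```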
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index ddc195bc32..cb50579fd2 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, JoinRules, Membership, UserTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -179,7 +179,6 @@ class Auth(object):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)
- @opentracing.trace
@defer.inlineCallbacks
def get_user_by_req(
self, request, allow_guest=False, rights="access", allow_expired=False
@@ -212,6 +211,7 @@ class Auth(object):
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
+ opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self.hs.config.track_appservice_user_ips:
yield self.store.insert_client_ip(
@@ -263,6 +263,8 @@ class Auth(object):
request.authenticated_entity = user.to_string()
opentracing.set_tag("authenticated_entity", user.to_string())
+ if device_id:
+ opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user, token_id, is_guest, device_id, app_service=app_service
@@ -709,7 +711,7 @@ class Auth(object):
)
@defer.inlineCallbacks
- def check_auth_blocking(self, user_id=None, threepid=None):
+ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
@@ -722,6 +724,9 @@ class Auth(object):
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
+
+ user_type (str|None): If present, is used to decide whether to check against
+ certain blocking reasons like MAU.
"""
# Never fail an auth check for the server notices users or support user
@@ -759,6 +764,10 @@ class Auth(object):
self.hs.config.mau_limits_reserved_threepids, threepid
):
return
+ elif user_type == UserTypes.SUPPORT:
+ # If the user does not exist yet and is of type "support",
+ # allow registration. Support users are excluded from MAU checks.
+ return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
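To illustrate the new `user_type` parameter on `check_auth_blocking`: a hedged sketch of a caller exempting a support user from the MAU limit before the account exists (the surrounding function is illustrative, not taken from this diff):

```python
from twisted.internet import defer

from synapse.api.constants import UserTypes


@defer.inlineCallbacks
def register_support_user(auth, localpart):
    # Passing user_type lets check_auth_blocking return early for
    # support users, who are excluded from MAU counting but do not
    # exist in the users table yet.
    yield auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
    # ... proceed to create the account (elided) ...
```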
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f29bce560c..60e99e4663 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -97,8 +97,6 @@ class EventTypes(object):
class RejectedReason(object):
AUTH_ERROR = "auth_error"
- REPLACED = "replaced"
- NOT_ANCESTOR = "not_ancestor"
class RoomCreationPreset(object):
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index a18d31db0c..cca92c34ba 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,6 +17,7 @@
"""Contains exceptions and error codes."""
import logging
+from typing import Dict
from six import iteritems
from six.moves import http_client
@@ -112,7 +113,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
- self._additional_fields = {}
+ self._additional_fields = {} # type: Dict
else:
self._additional_fields = dict(additional_fields)
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 95292b7dec..c6f50fd7b9 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import Dict
+
import attr
@@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
)
-} # type: dict[str, RoomVersion]
+} # type: Dict[str, RoomVersion]
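The switch to `typing.Dict` matters because mypy rejects subscripting the builtin `dict` in type comments when targeting the Python versions supported here (builtin generics only became legal in 3.9). A minimal illustration, assuming a recent mypy:

```python
from typing import Dict

known = {}  # type: Dict[str, int]   # accepted everywhere

# The old spelling subscripted the builtin, which mypy rejects for
# pre-3.9 targets:
# known = {}  # type: dict[str, int]
```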
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index c30fdeee9a..2ac7d5c064 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs)
# Start the tracer
- synapse.logging.opentracing.init_tracer(hs.config)
+ synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
+ hs.config
+ )
# It is now safe to start your Synapse.
hs.start_listening(listeners)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 04f1ed14f3..eb54f56853 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -561,10 +561,12 @@ def run(hs):
stats["database_engine"] = hs.get_datastore().database_engine_name
stats["database_server_version"] = hs.get_datastore().get_server_version()
- logger.info("Reporting stats to matrix.org: %s" % (stats,))
+ logger.info(
+ "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
+ )
try:
yield hs.get_simple_http_client().put_json(
- "https://matrix.org/report-usage-stats/push", stats
+ hs.config.report_stats_endpoint, stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
@@ -603,13 +605,13 @@ def run(hs):
@defer.inlineCallbacks
def generate_monthly_active_users():
current_mau_count = 0
- reserved_count = 0
+ reserved_users = ()
store = hs.get_datastore()
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
current_mau_count = yield store.get_monthly_active_count()
- reserved_count = yield store.get_registered_reserved_users_count()
+ reserved_users = yield store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
- registered_reserved_users_mau_gauge.set(float(reserved_count))
+ registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.max_mau_value))
def start_generate_monthly_active_users():
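For context, the gauges set here are ordinary `prometheus_client` gauges; the change fetches the reserved-user list itself and reports its length. A sketch with an illustrative metric name (not copied from Synapse):

```python
from prometheus_client import Gauge

registered_reserved_users_mau_gauge = Gauge(
    "registered_reserved_users_mau",  # illustrative name
    "Registered users with reserved threepids",
)

# get_registered_reserved_users() now returns the users rather than a
# count, so the gauge is set from the length of the result.
reserved_users = ["@alice:example.com", "@bob:example.com"]
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
```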
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 31f6530978..08619404bb 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,7 +18,9 @@
import argparse
import errno
import os
+from collections import OrderedDict
from textwrap import dedent
+from typing import Any, MutableMapping, Optional
from six import integer_types
@@ -51,7 +53,56 @@ Missing mandatory `server_name` config option.
"""
+def path_exists(file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+
class Config(object):
+ """
+ A configuration section, containing configuration keys and values.
+
+ Attributes:
+ section (str): The section title of this config object, such as
+ "tls" or "logger". This is used to refer to it on the root
+ config object (for example, `config.tls.some_option`). Must be
+ defined in subclasses.
+ """
+
+ section = None
+
+ def __init__(self, root_config=None):
+ self.root = root_config
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Try to fetch a configuration option that does not exist on this class.
+
+ This is so that existing configs that rely on `self.value`, where value
+ is actually from a different config section, continue to work.
+ """
+ if item in ["generate_config_section", "read_config"]:
+ raise AttributeError(item)
+
+ if self.root is None:
+ raise AttributeError(item)
+ else:
+ return self.root._get_unclassed_config(self.section, item)
+
@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
@@ -88,22 +139,7 @@ class Config(object):
@classmethod
def path_exists(cls, file_path):
- """Check if a file exists
-
- Unlike os.path.exists, this throws an exception if there is an error
- checking if the file exists (for example, if there is a perms error on
- the parent dir).
-
- Returns:
- bool: True if the file exists; False if not.
- """
- try:
- os.stat(file_path)
- return True
- except OSError as e:
- if e.errno != errno.ENOENT:
- raise e
- return False
+ return path_exists(file_path)
@classmethod
def check_file(cls, file_path, config_name):
@@ -136,42 +172,106 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
- def invoke_all(self, name, *args, **kargs):
- """Invoke all instance methods with the given name and arguments in the
- class's MRO.
+
+class RootConfig(object):
+ """
+ Holder of an application's configuration.
+
+ What configuration this object holds is defined by `config_classes`, a list
+ of Config classes that will be instantiated and given the contents of a
+ configuration file to read. They can then be accessed on this object by
+ their section name, as given by each Config subclass's `section`
+ attribute.
+ """
+
+ config_classes = []
+
+ def __init__(self):
+ self._configs = OrderedDict()
+
+ for config_class in self.config_classes:
+ if config_class.section is None:
+ raise ValueError("%r requires a section name" % (config_class,))
+
+ try:
+ conf = config_class(self)
+ except Exception as e:
+ raise Exception("Failed making %s: %r" % (config_class.section, e))
+ self._configs[config_class.section] = conf
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Redirect lookups on this object either to config objects or to values on
+ config objects, so that `config.tls.blah` works, as well as legacy uses
+ of things like `config.server_name`. It will first look up the config
+ section name, and then values on those config classes.
+ """
+ if item in self._configs.keys():
+ return self._configs[item]
+
+ return self._get_unclassed_config(None, item)
+
+ def _get_unclassed_config(self, asking_section: Optional[str], item: str):
+ """
+ Fetch a config value from one of the instantiated config classes that
+ has not been fetched directly.
+
+ Args:
+ asking_section: If this check is coming from a Config child, which
+ one? This section will not be asked if it has the value.
+ item: The configuration value key.
+
+ Raises:
+ AttributeError if no config classes have the config key. The error
+ message will list which sections were checked.
+ """
+ for key, val in self._configs.items():
+ if key == asking_section:
+ continue
+
+ if item in dir(val):
+ return getattr(val, item)
+
+ raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+
+ def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+ """
+ Invoke a function on all instantiated config objects this RootConfig is
+ configured to use.
Args:
- name (str): Name of function to invoke
+ func_name: Name of function to invoke
*args
**kwargs
-
Returns:
- list: The list of the return values from each method called
+ ordered dictionary mapping each config section name to the return
+ value of the named function on it.
"""
- results = []
- for cls in type(self).mro():
- if name in cls.__dict__:
- results.append(getattr(cls, name)(self, *args, **kargs))
- return results
+ res = OrderedDict()
+
+ for name, config in self._configs.items():
+ if hasattr(config, func_name):
+ res[name] = getattr(config, func_name)(*args, **kwargs)
+
+ return res
@classmethod
- def invoke_all_static(cls, name, *args, **kargs):
- """Invoke all static methods with the given name and arguments in the
- class's MRO.
+ def invoke_all_static(cls, func_name: str, *args, **kwargs):
+ """
+ Invoke a static function on config objects this RootConfig is
+ configured to use.
Args:
- name (str): Name of function to invoke
+ func_name: Name of function to invoke
*args
**kwargs
-
Returns:
- list: The list of the return values from each method called
+ None; the static methods are invoked for their side effects only.
"""
- results = []
- for c in cls.mro():
- if name in c.__dict__:
- results.append(getattr(c, name)(*args, **kargs))
- return results
+ for config in cls.config_classes:
+ if hasattr(config, func_name):
+ getattr(config, func_name)(*args, **kwargs)
def generate_config(
self,
@@ -187,7 +287,8 @@ class Config(object):
tls_private_key_path=None,
acme_domain=None,
):
- """Build a default configuration file
+ """
+ Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
@@ -242,6 +343,7 @@ class Config(object):
Returns:
str: the yaml config file
"""
+
return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
@@ -257,7 +359,7 @@ class Config(object):
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
- )
+ ).values()
)
@classmethod
@@ -444,7 +546,7 @@ class Config(object):
)
(config_path,) = config_files
- if not cls.path_exists(config_path):
+ if not path_exists(config_path):
print("Generating config file %s" % (config_path,))
if config_args.data_directory:
@@ -469,7 +571,7 @@ class Config(object):
open_private_ports=config_args.open_private_ports,
)
- if not cls.path_exists(config_dir_path):
+ if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_file.write("# vim:ft=yaml\n\n")
@@ -518,7 +620,7 @@ class Config(object):
return obj
- def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+ def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object.
Args:
@@ -607,3 +709,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files
+
+
+__all__ = ["Config", "RootConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
new file mode 100644
index 0000000000..86bc965ee4
--- /dev/null
+++ b/synapse/config/_base.pyi
@@ -0,0 +1,135 @@
+from typing import Any, List, Optional
+
+from synapse.config import (
+ api,
+ appservice,
+ captcha,
+ cas,
+ consent_config,
+ database,
+ emailconfig,
+ groups,
+ jwt_config,
+ key,
+ logger,
+ metrics,
+ password,
+ password_auth_providers,
+ push,
+ ratelimiting,
+ registration,
+ repository,
+ room_directory,
+ saml2_config,
+ server,
+ server_notices_config,
+ spam_checker,
+ stats,
+ third_party_event_rules,
+ tls,
+ tracer,
+ user_directory,
+ voip,
+ workers,
+)
+
+class ConfigError(Exception): ...
+
+MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
+MISSING_REPORT_STATS_SPIEL: str
+MISSING_SERVER_NAME: str
+
+def path_exists(file_path: str): ...
+
+class RootConfig:
+ server: server.ServerConfig
+ tls: tls.TlsConfig
+ database: database.DatabaseConfig
+ logging: logger.LoggingConfig
+ ratelimit: ratelimiting.RatelimitConfig
+ media: repository.ContentRepositoryConfig
+ captcha: captcha.CaptchaConfig
+ voip: voip.VoipConfig
+ registration: registration.RegistrationConfig
+ metrics: metrics.MetricsConfig
+ api: api.ApiConfig
+ appservice: appservice.AppServiceConfig
+ key: key.KeyConfig
+ saml2: saml2_config.SAML2Config
+ cas: cas.CasConfig
+ jwt: jwt_config.JWTConfig
+ password: password.PasswordConfig
+ email: emailconfig.EmailConfig
+ worker: workers.WorkerConfig
+ authproviders: password_auth_providers.PasswordAuthProviderConfig
+ push: push.PushConfig
+ spamchecker: spam_checker.SpamCheckerConfig
+ groups: groups.GroupsConfig
+ userdirectory: user_directory.UserDirectoryConfig
+ consent: consent_config.ConsentConfig
+ stats: stats.StatsConfig
+ servernotices: server_notices_config.ServerNoticesConfig
+ roomdirectory: room_directory.RoomDirectoryConfig
+ thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
+ tracer: tracer.TracerConfig
+
+ config_classes: List = ...
+ def __init__(self) -> None: ...
+ def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
+ @classmethod
+ def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
+ def __getattr__(self, item: str): ...
+ def parse_config_dict(
+ self,
+ config_dict: Any,
+ config_dir_path: Optional[Any] = ...,
+ data_dir_path: Optional[Any] = ...,
+ ) -> None: ...
+ read_config: Any = ...
+ def generate_config(
+ self,
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ generate_secrets: bool = ...,
+ report_stats: Optional[str] = ...,
+ open_private_ports: bool = ...,
+ listeners: Optional[Any] = ...,
+ database_conf: Optional[Any] = ...,
+ tls_certificate_path: Optional[str] = ...,
+ tls_private_key_path: Optional[str] = ...,
+ acme_domain: Optional[str] = ...,
+ ): ...
+ @classmethod
+ def load_or_generate_config(cls, description: Any, argv: Any): ...
+ @classmethod
+ def load_config(cls, description: Any, argv: Any): ...
+ @classmethod
+ def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
+ @classmethod
+ def load_config_with_parser(cls, parser: Any, argv: Any): ...
+ def generate_missing_files(
+ self, config_dict: dict, config_dir_path: str
+ ) -> None: ...
+
+class Config:
+ root: RootConfig
+ def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
+ def __getattr__(self, item: str, from_root: bool = ...): ...
+ @staticmethod
+ def parse_size(value: Any): ...
+ @staticmethod
+ def parse_duration(value: Any): ...
+ @staticmethod
+ def abspath(file_path: Optional[str]): ...
+ @classmethod
+ def path_exists(cls, file_path: str): ...
+ @classmethod
+ def check_file(cls, file_path: str, config_name: str): ...
+ @classmethod
+ def ensure_directory(cls, dir_path: str): ...
+ @classmethod
+ def read_file(cls, file_path: str, config_name: str): ...
+
+def read_config_files(config_files: List[str]): ...
+def find_config_files(search_paths: List[str]): ...
diff --git a/synapse/config/api.py b/synapse/config/api.py
index dddea79a8a..74cd53a8ed 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -18,6 +18,8 @@ from ._base import Config
class ApiConfig(Config):
+ section = "api"
+
def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get(
"room_invite_state_types",
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 8387ff6805..9b4682222d 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from typing import Dict
from six import string_types
from six.moves.urllib import parse as urlparse
@@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
+ section = "appservice"
+
def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
@@ -56,8 +59,8 @@ def load_appservices(hostname, config_files):
return []
# Dicts of value -> filename
- seen_as_tokens = {}
- seen_ids = {}
+ seen_as_tokens = {} # type: Dict[str, str]
+ seen_ids = {} # type: Dict[str, str]
appservices = []
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 8dac8152cf..44bd5c6799 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -16,6 +16,8 @@ from ._base import Config
class CaptchaConfig(Config):
+ section = "captcha"
+
def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index ebe34d933b..4526c1a67b 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -22,17 +22,21 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
+ section = "cas"
+
def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
+ self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes", {})
else:
self.cas_enabled = False
self.cas_server_url = None
self.cas_service_url = None
+ self.cas_displayname_attribute = None
self.cas_required_attributes = {}
def generate_config_section(self, config_dir_path, server_name, **kwargs):
@@ -43,6 +47,7 @@ class CasConfig(Config):
# enabled: true
# server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448"
+ # #displayname_attribute: name
# #required_attributes:
# # name: value
"""
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 94916f3a49..62c4c44d60 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -73,8 +73,11 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config):
- def __init__(self):
- super(ConsentConfig, self).__init__()
+
+ section = "consent"
+
+ def __init__(self, *args):
+ super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None
self.user_consent_template_dir = None
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 118aafbd4a..0e2509f0b1 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -21,6 +21,8 @@ from ._base import Config
class DatabaseConfig(Config):
+ section = "database"
+
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index e5de768b0c..658897a77e 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -28,6 +28,8 @@ from ._base import Config, ConfigError
class EmailConfig(Config):
+ section = "email"
+
def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.
@@ -169,12 +171,22 @@ class EmailConfig(Config):
self.email_registration_template_text = email_config.get(
"registration_template_text", "registration.txt"
)
+ self.email_add_threepid_template_html = email_config.get(
+ "add_threepid_template_html", "add_threepid.html"
+ )
+ self.email_add_threepid_template_text = email_config.get(
+ "add_threepid_template_text", "add_threepid.txt"
+ )
+
self.email_password_reset_template_failure_html = email_config.get(
"password_reset_template_failure_html", "password_reset_failure.html"
)
self.email_registration_template_failure_html = email_config.get(
"registration_template_failure_html", "registration_failure.html"
)
+ self.email_add_threepid_template_failure_html = email_config.get(
+ "add_threepid_template_failure_html", "add_threepid_failure.html"
+ )
# These templates do not support any placeholder variables, so we
# will read them from disk once during setup
@@ -184,6 +196,9 @@ class EmailConfig(Config):
email_registration_template_success_html = email_config.get(
"registration_template_success_html", "registration_success.html"
)
+ email_add_threepid_template_success_html = email_config.get(
+ "add_threepid_template_success_html", "add_threepid_success.html"
+ )
# Check templates exist
for f in [
@@ -191,9 +206,14 @@ class EmailConfig(Config):
self.email_password_reset_template_text,
self.email_registration_template_html,
self.email_registration_template_text,
+ self.email_add_threepid_template_html,
+ self.email_add_threepid_template_text,
self.email_password_reset_template_failure_html,
+ self.email_registration_template_failure_html,
+ self.email_add_threepid_template_failure_html,
email_password_reset_template_success_html,
email_registration_template_success_html,
+ email_add_threepid_template_success_html,
]:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
@@ -212,6 +232,12 @@ class EmailConfig(Config):
self.email_registration_template_success_html_content = self.read_file(
filepath, "email.registration_template_success_html"
)
+ filepath = os.path.join(
+ self.email_template_dir, email_add_threepid_template_success_html
+ )
+ self.email_add_threepid_template_success_html_content = self.read_file(
+ filepath, "email.add_threepid_template_success_html"
+ )
if self.email_enable_notifs:
required = [
@@ -328,6 +354,12 @@ class EmailConfig(Config):
# #registration_template_html: registration.html
# #registration_template_text: registration.txt
#
+ # # Templates for validation emails sent by the homeserver when adding an email to
+ # # your user account
+ # #
+ # #add_threepid_template_html: add_threepid.html
+ # #add_threepid_template_text: add_threepid.txt
+ #
# # Templates for password reset success and failure pages that a user
# # will see after attempting to reset their password
# #
@@ -339,6 +371,12 @@ class EmailConfig(Config):
# #
# #registration_template_success_html: registration_success.html
# #registration_template_failure_html: registration_failure.html
+ #
+ # # Templates for success and failure pages that a user will see after attempting
+ # # to add an email or phone to their account
+ # #
+ # #add_threepid_template_success_html: add_threepid_success.html
+ # #add_threepid_template_failure_html: add_threepid_failure.html
"""
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index 2a522b5f44..d6862d9a64 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -17,6 +17,8 @@ from ._base import Config
class GroupsConfig(Config):
+ section = "groups"
+
def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72acad4f18..6e348671c7 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -46,36 +47,37 @@ from .voip import VoipConfig
from .workers import WorkerConfig
-class HomeServerConfig(
- ServerConfig,
- TlsConfig,
- DatabaseConfig,
- LoggingConfig,
- RatelimitConfig,
- ContentRepositoryConfig,
- CaptchaConfig,
- VoipConfig,
- RegistrationConfig,
- MetricsConfig,
- ApiConfig,
- AppServiceConfig,
- KeyConfig,
- SAML2Config,
- CasConfig,
- JWTConfig,
- PasswordConfig,
- EmailConfig,
- WorkerConfig,
- PasswordAuthProviderConfig,
- PushConfig,
- SpamCheckerConfig,
- GroupsConfig,
- UserDirectoryConfig,
- ConsentConfig,
- StatsConfig,
- ServerNoticesConfig,
- RoomDirectoryConfig,
- ThirdPartyRulesConfig,
- TracerConfig,
-):
- pass
+class HomeServerConfig(RootConfig):
+
+ config_classes = [
+ ServerConfig,
+ TlsConfig,
+ DatabaseConfig,
+ LoggingConfig,
+ RatelimitConfig,
+ ContentRepositoryConfig,
+ CaptchaConfig,
+ VoipConfig,
+ RegistrationConfig,
+ MetricsConfig,
+ ApiConfig,
+ AppServiceConfig,
+ KeyConfig,
+ SAML2Config,
+ CasConfig,
+ JWTConfig,
+ PasswordConfig,
+ EmailConfig,
+ WorkerConfig,
+ PasswordAuthProviderConfig,
+ PushConfig,
+ SpamCheckerConfig,
+ GroupsConfig,
+ UserDirectoryConfig,
+ ConsentConfig,
+ StatsConfig,
+ ServerNoticesConfig,
+ RoomDirectoryConfig,
+ ThirdPartyRulesConfig,
+ TracerConfig,
+ ]
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index 36d87cef03..a568726985 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config):
+ section = "jwt"
+
def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ba2199bceb..ec5d430afb 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -50,6 +50,33 @@ and you should enable 'federation_verify_certificates' in your configuration.
If you are *sure* you want to do this, set 'accept_keys_insecurely' on the
trusted_key_server configuration."""
+TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN = """\
+Synapse requires that a list of trusted key servers be specified in order to
+provide signing keys for other servers in the federation.
+
+This homeserver does not have a trusted key server configured in
+homeserver.yaml and will fall back to the default of 'matrix.org'.
+
+Trusted key servers should be long-lived and stable, which makes matrix.org a
+good choice for many admins, but some admins may wish to choose another. To
+suppress this warning, the admin should set 'trusted_key_servers' in
+homeserver.yaml to their desired key server and 'suppress_key_server_warning'
+to 'true'.
+
+In a future release the software-defined default will be removed entirely and
+the trusted key server will be defined exclusively by the value of
+'trusted_key_servers'.
+--------------------------------------------------------------------------------"""
+
+TRUSTED_KEY_SERVER_CONFIGURED_AS_M_ORG_WARN = """\
+This server is configured to use 'matrix.org' as its trusted key server via the
+'trusted_key_servers' config option. 'matrix.org' is a good choice for a key
+server since it is long-lived, stable and trusted. However, some admins may
+wish to use another server for this purpose.
+
+To suppress this warning and continue using 'matrix.org', admins should set
+'suppress_key_server_warning' to 'true' in homeserver.yaml.
+--------------------------------------------------------------------------------"""
logger = logging.getLogger(__name__)
@@ -65,6 +92,8 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
+ section = "key"
+
def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
@@ -85,6 +114,7 @@ class KeyConfig(Config):
config.get("key_refresh_interval", "1d")
)
+ suppress_key_server_warning = config.get("suppress_key_server_warning", False)
key_server_signing_keys_path = config.get("key_server_signing_keys_path")
if key_server_signing_keys_path:
self.key_server_signing_keys = self.read_signing_keys(
@@ -95,6 +125,7 @@ class KeyConfig(Config):
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
+ logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
key_servers = [{"server_name": "matrix.org"}]
else:
key_servers = config.get("trusted_key_servers", [])
@@ -108,6 +139,11 @@ class KeyConfig(Config):
# merge the 'perspectives' config into the 'trusted_key_servers' config.
key_servers.extend(_perspectives_to_key_servers(config))
+ if not suppress_key_server_warning and "matrix.org" in (
+ s["server_name"] for s in key_servers
+ ):
+ logger.warning(TRUSTED_KEY_SERVER_CONFIGURED_AS_M_ORG_WARN)
+
# list of TrustedKeyServer objects
self.key_servers = list(
_parse_key_servers(key_servers, self.federation_verify_certificates)
@@ -190,6 +226,10 @@ class KeyConfig(Config):
# This setting supersedes an older setting named `perspectives`. The old format
# is still supported for backwards-compatibility, but it is deprecated.
#
+ # 'trusted_key_servers' defaults to matrix.org, but relying on that default
+ # will generate a warning on start-up. To suppress this warning, set
+ # 'suppress_key_server_warning' to true.
+ #
# Options for each entry in the list include:
#
# server_name: the name of the server. required.
@@ -214,11 +254,13 @@ class KeyConfig(Config):
# "ed25519:auto": "abcdefghijklmnopqrstuvwxyzabcdefghijklmopqr"
# - server_name: "my_other_trusted_server.example.com"
#
- # The default configuration is:
- #
- #trusted_key_servers:
- # - server_name: "matrix.org"
+ trusted_key_servers:
+ - server_name: "matrix.org"
+
+ # Uncomment the following to disable the warning that is emitted when the
+ # trusted_key_servers include 'matrix.org'. See above.
#
+ #suppress_key_server_warning: true
# The signing keys to use when acting as a trusted key server. If not specified
# defaults to the server signing key.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 2704c18720..be92e33f93 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,7 +21,12 @@ from string import Template
import yaml
-from twisted.logger import STDLibLogObserver, globalLogBeginner
+from twisted.logger import (
+ ILogObserver,
+ LogBeginner,
+ STDLibLogObserver,
+ globalLogBeginner,
+)
import synapse
from synapse.app import _base as appbase
@@ -63,9 +68,6 @@ handlers:
filters: [context]
loggers:
- synapse:
- level: INFO
-
synapse.storage.SQL:
# beware: increasing this to DEBUG will make synapse log sensitive
# information such as access tokens.
@@ -74,11 +76,15 @@ loggers:
root:
level: INFO
handlers: [file, console]
+
+disable_existing_loggers: false
"""
)
class LoggingConfig(Config):
+ section = "logging"
+
def read_config(self, config, **kwargs):
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
@@ -124,7 +130,7 @@ class LoggingConfig(Config):
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
-def _setup_stdlib_logging(config, log_config):
+def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
"""
Set up Python stdlib logging.
"""
@@ -165,12 +171,12 @@ def _setup_stdlib_logging(config, log_config):
return observer(event)
- globalLogBeginner.beginLoggingTo(
- [_log], redirectStandardIO=not config.no_redirect_stdio
- )
+ logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")
+ return observer
+
def _reload_stdlib_logging(*args, log_config=None):
logger = logging.getLogger("")
@@ -181,7 +187,9 @@ def _reload_stdlib_logging(*args, log_config=None):
logging.config.dictConfig(log_config)
-def setup_logging(hs, config, use_worker_options=False):
+def setup_logging(
+ hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
+) -> ILogObserver:
"""
Set up the logging subsystem.
@@ -191,6 +199,12 @@ def setup_logging(hs, config, use_worker_options=False):
use_worker_options (bool): True to use the 'worker_log_config' option
instead of 'log_config'.
+
+ logBeginner: The Twisted logBeginner to use.
+
+ Returns:
+ The "root" Twisted Logger observer, suitable for sending logs to from a
+ Logger instance.
"""
log_config = config.worker_log_config if use_worker_options else config.log_config
@@ -210,10 +224,12 @@ def setup_logging(hs, config, use_worker_options=False):
log_config_body = read_config()
if log_config_body and log_config_body.get("structured") is True:
- setup_structured_logging(hs, config, log_config_body)
+ logger = setup_structured_logging(
+ hs, config, log_config_body, logBeginner=logBeginner
+ )
appbase.register_sighup(read_config, callback=reload_structured_logging)
else:
- _setup_stdlib_logging(config, log_config_body)
+ logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
# make sure that the first thing we log is a thing we can grep backwards
@@ -221,3 +237,5 @@ def setup_logging(hs, config, use_worker_options=False):
logging.warn("***** STARTING SERVER *****")
logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
+
+ return logger
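The point of threading `logBeginner` through is testability: a test can hand `setup_logging` a private `LogBeginner` instead of the process-global one and capture everything that is logged. A hedged sketch using Twisted's public API:

```python
import sys
import warnings
from io import StringIO

from twisted.logger import LogBeginner, Logger, LogPublisher

# A private publisher/beginner pair, detached from the global logging
# state, as a test harness might build before calling
# setup_logging(hs, config, logBeginner=beginner).
publisher = LogPublisher()
beginner = LogBeginner(publisher, StringIO(), sys, warnings)

events = []
beginner.beginLoggingTo([events.append], redirectStandardIO=False)

Logger(observer=publisher).info("hello {name}", name="world")
assert events[0]["name"] == "world"
```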
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 653b990e67..282a43bddb 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -16,11 +16,9 @@
import attr
-from ._base import Config, ConfigError
+from synapse.python_dependencies import DependencyException, check_requirements
-MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry
- integration.
- """
+from ._base import Config, ConfigError
@attr.s
@@ -36,9 +34,14 @@ class MetricsFlags(object):
class MetricsConfig(Config):
+ section = "metrics"
+
def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
+ self.report_stats_endpoint = config.get(
+ "report_stats_endpoint", "https://matrix.org/report-usage-stats/push"
+ )
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
@@ -51,9 +54,9 @@ class MetricsConfig(Config):
self.sentry_enabled = "sentry" in config
if self.sentry_enabled:
try:
- import sentry_sdk # noqa F401
- except ImportError:
- raise ConfigError(MISSING_SENTRY)
+ check_requirements("sentry")
+ except DependencyException as e:
+ raise ConfigError(e.message)
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
@@ -97,4 +100,10 @@ class MetricsConfig(Config):
else:
res += "report_stats: %s\n" % ("true" if report_stats else "false")
+ res += """
+ # The endpoint to report the anonymized homeserver usage statistics to.
+ # Defaults to https://matrix.org/report-usage-stats/push
+ #
+ #report_stats_endpoint: https://example.com/report-usage-stats/push
+ """
return res
diff --git a/synapse/config/password.py b/synapse/config/password.py
index d5b5953f2f..2a634ac751 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -20,6 +20,8 @@ class PasswordConfig(Config):
"""Password login configuration
"""
+ section = "password"
+
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 788c39c9fb..9746bbc681 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, List
+
from synapse.util.module_loader import load_module
from ._base import Config
@@ -21,8 +23,10 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
+ section = "authproviders"
+
def read_config(self, config, **kwargs):
- self.password_providers = []
+ self.password_providers = [] # type: List[Any]
providers = []
# We want to be backwards compatible with the old `ldap_config`
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 1b932722a5..0910958649 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -18,6 +18,8 @@ from ._base import Config
class PushConfig(Config):
+ section = "push"
+
def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 33f31cf213..947f653e03 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -36,6 +36,8 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
+ section = "ratelimiting"
+
def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back
@@ -80,6 +82,12 @@ class RatelimitConfig(Config):
"federation_rr_transactions_per_room_per_second", 50
)
+ rc_admin_redaction = config.get("rc_admin_redaction")
+ if rc_admin_redaction:
+ self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
+ else:
+ self.rc_admin_redaction = None
+
def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
@@ -102,6 +110,9 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one for ratelimiting redactions by room admins. If this is not explicitly
+ # set, it uses the same ratelimiting as rc_message. This is useful
+ # to allow room admins to deal with abuse quickly.
#
# The defaults are as shown below.
#
@@ -123,6 +134,10 @@ class RatelimitConfig(Config):
# failed_attempts:
# per_second: 0.17
# burst_count: 3
+ #
+ #rc_admin_redaction:
+ # per_second: 1
+ # burst_count: 50
# Ratelimiting settings for incoming federation
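The `per_second`/`burst_count` pair for `rc_admin_redaction` follows the same semantics as the other `rc_*` blocks. A sketch of the underlying sliding-window arithmetic (illustrative only, not Synapse's `Ratelimiter`):

```python
import time


def allowed(history, now, per_second, burst_count):
    """Return True if an action at `now` is within the rate limit.

    `history` is the list of timestamps of previously allowed actions;
    it is pruned and updated in place.
    """
    # Drop actions that have "leaked" out of the window.
    window = burst_count / per_second
    history[:] = [t for t in history if now - t < window]
    if len(history) >= burst_count:
        return False
    history.append(now)
    return True


hits = []
now = time.time()
# With per_second=1 and burst_count=50, 50 prompt redactions pass...
assert all(allowed(hits, now, 1, 50) for _ in range(50))
# ...and the 51st within the same window is limited.
assert not allowed(hits, now, 1, 50)
```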
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 9548560edb..b3e3e6dda2 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -24,6 +24,8 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
+ section = "accountvalidity"
+
def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@@ -77,6 +79,8 @@ class AccountValidityConfig(Config):
class RegistrationConfig(Config):
+ section = "registration"
+
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
@@ -293,8 +297,10 @@ class RegistrationConfig(Config):
# by the Matrix Identity Service API specification:
# https://matrix.org/docs/spec/identity_service/latest
#
+ # If a delegate is specified, the config option public_baseurl must also be filled out.
+ #
account_threepid_delegates:
- #email: https://example.com # Delegate email sending to matrix.org
+ #email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Users who register on this homeserver will automatically be joined
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index fdb1f246d0..d0205e14b9 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 matrix.org
+# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,9 @@
import os
from collections import namedtuple
+from typing import Dict, List
+from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
@@ -34,17 +36,6 @@ THUMBNAIL_SIZE_YAML = """\
# method: %(method)s
"""
-MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API."
-
-MISSING_LXML = """Missing lxml library. This is required for URL preview API.
-
- Install by running:
- pip install lxml
-
- Requires libxslt1-dev system package.
- """
-
-
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
@@ -71,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
- requirements = {}
+ requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@@ -87,6 +78,8 @@ def parse_thumbnail_requirements(thumbnail_sizes):
class ContentRepositoryConfig(Config):
+ section = "media"
+
def read_config(self, config, **kwargs):
# Only enable the media repo if either the media repo is enabled or the
@@ -140,7 +133,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
- self.media_storage_providers = []
+ self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to
@@ -171,16 +164,10 @@ class ContentRepositoryConfig(Config):
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
try:
- import lxml
+ check_requirements("url_preview")
- lxml # To stop unused lint.
- except ImportError:
- raise ConfigError(MISSING_LXML)
-
- try:
- from netaddr import IPSet
- except ImportError:
- raise ConfigError(MISSING_NETADDR)
+ except DependencyException as e:
+ raise ConfigError(e.message)
if "url_preview_ip_range_blacklist" not in config:
raise ConfigError(
@@ -189,6 +176,9 @@ class ContentRepositoryConfig(Config):
"to work"
)
+ # netaddr is a dependency for url_preview
+ from netaddr import IPSet
+
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
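Both this file and `metrics.py` above now funnel optional-dependency failures through the same helper instead of hand-rolled `ImportError` messages. A sketch of the pattern (`require_extra` is a hypothetical wrapper, not Synapse code):

```python
from synapse.config._base import ConfigError
from synapse.python_dependencies import DependencyException, check_requirements


def require_extra(extra_name):
    # check_requirements raises DependencyException listing the missing
    # packages for the named extra (e.g. "url_preview", "sentry");
    # surfacing it as a ConfigError fails startup with a readable message.
    try:
        check_requirements(extra_name)
    except DependencyException as e:
        raise ConfigError(e.message)
```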
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index a92693017b..7c9f05bde4 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
+ section = "roomdirectory"
+
def read_config(self, config, **kwargs):
self.enable_room_list_search = config.get("enable_room_list_search", True)
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 6a8161547a..c407e13680 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +13,50 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import re
+
from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import (
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
+from synapse.util.module_loader import load_python_module
from ._base import Config, ConfigError
+def _dict_merge(merge_dict, into_dict):
+ """Do a deep merge of two dicts
+
+ Recursively merges `merge_dict` into `into_dict`:
+ * For keys where both `merge_dict` and `into_dict` have a dict value, the values
+ are recursively merged
+ * For all other keys, the values in `into_dict` (if any) are overwritten with
+ the value from `merge_dict`.
+
+ Args:
+ merge_dict (dict): dict to merge
+ into_dict (dict): target dict
+ """
+ for k, v in merge_dict.items():
+ if k not in into_dict:
+ into_dict[k] = v
+ continue
+
+ current_val = into_dict[k]
+
+ if isinstance(v, dict) and isinstance(current_val, dict):
+ _dict_merge(v, current_val)
+ continue
+
+ # otherwise we just overwrite
+ into_dict[k] = v
+
+
class SAML2Config(Config):
+ section = "saml2"
+
def read_config(self, config, **kwargs):
self.saml2_enabled = False
@@ -26,6 +65,9 @@ class SAML2Config(Config):
if not saml2_config or not saml2_config.get("enabled", True):
return
+ if not saml2_config.get("sp_config") and not saml2_config.get("config_path"):
+ return
+
try:
check_requirements("saml2")
except DependencyException as e:
@@ -33,21 +75,40 @@ class SAML2Config(Config):
self.saml2_enabled = True
- import saml2.config
+ self.saml2_mxid_source_attribute = saml2_config.get(
+ "mxid_source_attribute", "uid"
+ )
- self.saml2_sp_config = saml2.config.SPConfig()
- self.saml2_sp_config.load(self._default_saml_config_dict())
- self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
+ self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
+ "grandfathered_mxid_source_attribute", "uid"
+ )
+
+ saml2_config_dict = self._default_saml_config_dict()
+ _dict_merge(
+ merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
+ )
config_path = saml2_config.get("config_path", None)
if config_path is not None:
- self.saml2_sp_config.load_file(config_path)
+ mod = load_python_module(config_path)
+ _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
+
+ import saml2.config
+
+ self.saml2_sp_config = saml2.config.SPConfig()
+ self.saml2_sp_config.load(saml2_config_dict)
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
saml2_config.get("saml_session_lifetime", "5m")
)
+ mapping = saml2_config.get("mxid_mapping", "hexencode")
+ try:
+ self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
+ except KeyError:
+ raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
+
def _default_saml_config_dict(self):
import saml2
@@ -55,6 +116,13 @@ class SAML2Config(Config):
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
+ required_attributes = {"uid", self.saml2_mxid_source_attribute}
+
+ optional_attributes = {"displayName"}
+ if self.saml2_grandfathered_mxid_source_attribute:
+ optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
+ optional_attributes -= required_attributes
+
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
@@ -66,8 +134,9 @@ class SAML2Config(Config):
(response_url, saml2.BINDING_HTTP_POST)
]
},
- "required_attributes": ["uid"],
- "optional_attributes": ["mail", "surname", "givenname"],
+ "required_attributes": list(required_attributes),
+ "optional_attributes": list(optional_attributes),
+ # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
}
},
}
@@ -76,12 +145,13 @@ class SAML2Config(Config):
return """\
# Enable SAML2 for registration and login. Uses pysaml2.
#
- # `sp_config` is the configuration for the pysaml2 Service Provider.
- # See pysaml2 docs for format of config.
+ # At least one of `sp_config` or `config_path` must be set in this section to
+ # enable SAML login.
#
- # Default values will be used for the 'entityid' and 'service' settings,
- # so it is not normally necessary to specify them unless you need to
- # override them.
+ # You will probably also want to set the following options to `false` to
+ # disable the regular login/registration flows:
+ # * enable_registration
+ # * password_config.enabled
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
@@ -89,52 +159,105 @@ class SAML2Config(Config):
# the IdP to use an ACS location of
# https://<server>:<port>/_matrix/saml2/authn_response.
#
- #saml2_config:
- # sp_config:
- # # point this to the IdP's metadata. You can use either a local file or
- # # (preferably) a URL.
- # metadata:
- # #local: ["saml2/idp.xml"]
- # remote:
- # - url: https://our_idp/metadata.xml
- #
- # # By default, the user has to go to our login page first. If you'd like to
- # # allow IdP-initiated login, set 'allow_unsolicited: True' in a
- # # 'service.sp' section:
- # #
- # #service:
- # # sp:
- # # allow_unsolicited: True
- #
- # # The examples below are just used to generate our metadata xml, and you
- # # may well not need it, depending on your setup. Alternatively you
- # # may need a whole lot more detail - see the pysaml2 docs!
- #
- # description: ["My awesome SP", "en"]
- # name: ["Test SP", "en"]
- #
- # organization:
- # name: Example com
- # display_name:
- # - ["Example co", "en"]
- # url: "http://example.com"
- #
- # contact_person:
- # - given_name: Bob
- # sur_name: "the Sysadmin"
- # email_address": ["admin@example.com"]
- # contact_type": technical
- #
- # # Instead of putting the config inline as above, you can specify a
- # # separate pysaml2 configuration file:
- # #
- # config_path: "%(config_dir_path)s/sp_conf.py"
- #
- # # the lifetime of a SAML session. This defines how long a user has to
- # # complete the authentication process, if allow_unsolicited is unset.
- # # The default is 5 minutes.
- # #
- # # saml_session_lifetime: 5m
+ saml2_config:
+ # `sp_config` is the configuration for the pysaml2 Service Provider.
+ # See pysaml2 docs for format of config.
+ #
+ # Default values will be used for the 'entityid' and 'service' settings,
+ # so it is not normally necessary to specify them unless you need to
+ # override them.
+ #
+ #sp_config:
+ # # point this to the IdP's metadata. You can use either a local file or
+ # # (preferably) a URL.
+ # metadata:
+ # #local: ["saml2/idp.xml"]
+ # remote:
+ # - url: https://our_idp/metadata.xml
+ #
+ # # By default, the user has to go to our login page first. If you'd like
+ # # to allow IdP-initiated login, set 'allow_unsolicited: True' in a
+ # # 'service.sp' section:
+ # #
+ # #service:
+ # # sp:
+ # # allow_unsolicited: true
+ #
+ # # The examples below are just used to generate our metadata xml, and you
+ # # may well not need them, depending on your setup. Alternatively you
+ # # may need a whole lot more detail - see the pysaml2 docs!
+ #
+ # description: ["My awesome SP", "en"]
+ # name: ["Test SP", "en"]
+ #
+ # organization:
+ # name: Example com
+ # display_name:
+ # - ["Example co", "en"]
+ # url: "http://example.com"
+ #
+ # contact_person:
+ # - given_name: Bob
+ # sur_name: "the Sysadmin"
+ # email_address": ["admin@example.com"]
+ # contact_type": technical
+
+ # Instead of putting the config inline as above, you can specify a
+ # separate pysaml2 configuration file:
+ #
+ #config_path: "%(config_dir_path)s/sp_conf.py"
+
+ # the lifetime of a SAML session. This defines how long a user has to
+ # complete the authentication process, if allow_unsolicited is unset.
+ # The default is 5 minutes.
+ #
+ #saml_session_lifetime: 5m
+
+ # The SAML attribute (after mapping via the attribute maps) to use to derive
+ # the Matrix ID from. 'uid' by default.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a matrix ID.
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with '.').
+ # The default is 'hexencode'.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of Synapse, the mapping from SAML attribute to MXID was
+ # always calculated dynamically rather than stored in a table. For backwards-
+ # compatibility, we will look for user_ids matching such a pattern before
+ # creating a new account.
+ #
+ # This setting controls the SAML attribute which will be used for this
+ # backwards-compatibility lookup. Typically it should be 'uid', but if the
+ # attribute maps are changed, it may be necessary to change it.
+ #
+ # The default is 'uid'.
+ #
+ #grandfathered_mxid_source_attribute: upn
""" % {
"config_dir_path": config_dir_path
}
+
+
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
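As a usage note, the two mapping modes described in the sample config behave like this (a quick sketch against the functions defined in this file; the hexencode path delegates to `map_username_to_mxid_localpart`, whose exact escaping is not shown here):

```python
from synapse.config.saml2_config import MXID_MAPPER_MAP, dot_replace_for_mxid

# 'dotreplace' lower-cases, flattens anything outside the allowed
# localpart alphabet to '.', and strips a leading underscore:
assert dot_replace_for_mxid("First Läst") == "first.l.st"
assert dot_replace_for_mxid("_admin") == "admin"

# The mxid_mapping config option selects one of the registered mappers:
mapper = MXID_MAPPER_MAP["dotreplace"]
assert mapper("Bob Smith") == "bob.smith"
```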
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 2abdef0971..afc4d6a4ab 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,6 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
+from typing import List
import attr
import yaml
@@ -48,8 +49,17 @@ ROOM_COMPLEXITY_TOO_GREAT = (
"to join this room."
)
+METRICS_PORT_WARNING = """\
+The metrics_port configuration option is deprecated in Synapse 0.31 in favour of
+a listener. Please see
+https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
+on how to configure the new listener.
+--------------------------------------------------------------------------------"""
+
class ServerConfig(Config):
+ section = "server"
+
def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@@ -162,6 +172,23 @@ class ServerConfig(Config):
self.mau_trial_days = config.get("mau_trial_days", 0)
+ # How long to keep the unredacted copy of redacted events in the
+ # database before censoring it.
+ redaction_retention_period = config.get("redaction_retention_period", "7d")
+ if redaction_retention_period is not None:
+ self.redaction_retention_period = self.parse_duration(
+ redaction_retention_period
+ )
+ else:
+ self.redaction_retention_period = None
+
+ # How long to keep entries in the `user_ips` table.
+ user_ips_max_age = config.get("user_ips_max_age", "28d")
+ if user_ips_max_age is not None:
+ self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+ else:
+ self.user_ips_max_age = None
+
# Options to disable HS
self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "")
@@ -219,7 +246,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
- self.listeners = []
+ self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
raise ConfigError(
@@ -263,7 +290,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False
)
complexity = attr.ib(
- validator=attr.validators.instance_of((int, float)), default=1.0
+ validator=attr.validators.instance_of(
+ (float, int) # type: ignore[arg-type] # noqa
+ ),
+ default=1.0,
)
complexity_error = attr.ib(
validator=attr.validators.instance_of(str),
@@ -324,14 +354,7 @@ class ServerConfig(Config):
metrics_port = config.get("metrics_port")
if metrics_port:
- logger.warn(
- (
- "The metrics_port configuration option is deprecated in Synapse 0.31 "
- "in favour of a listener. Please see "
- "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
- " on how to configure the new listener."
- )
- )
+ logger.warning(METRICS_PORT_WARNING)
self.listeners.append(
{
@@ -345,13 +368,11 @@ class ServerConfig(Config):
_check_resource_config(self.listeners)
- # An experimental option to try and periodically clean up extremities
- # by sending dummy events.
self.cleanup_extremities_with_dummy_events = config.get(
- "cleanup_extremities_with_dummy_events", False
+ "cleanup_extremities_with_dummy_events", True
)
- def has_tls_listener(self):
+ def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)
def generate_config_section(
@@ -535,6 +556,9 @@ class ServerConfig(Config):
# blacklist IP address CIDR ranges. If this option is not specified, or
# specified with an empty list, no ip range blacklist will be enforced.
#
+ # As of Synapse v1.4.0 this option also affects any outbound requests to identity
+ # servers provided by user input.
+ #
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
@@ -561,8 +585,8 @@ class ServerConfig(Config):
#
# type: the type of listener. Normally 'http', but other valid options are:
# 'manhole' (see docs/manhole.md),
- # 'metrics' (see docs/metrics-howto.rst),
- # 'replication' (see docs/workers.rst).
+ # 'metrics' (see docs/metrics-howto.md),
+ # 'replication' (see docs/workers.md).
#
# tls: set to true to enable TLS for this listener. Will use the TLS
# key/cert specified in tls_private_key_path / tls_certificate_path.
@@ -597,12 +621,12 @@ class ServerConfig(Config):
#
# media: the media API (/_matrix/media).
#
- # metrics: the metrics interface. See docs/metrics-howto.rst.
+ # metrics: the metrics interface. See docs/metrics-howto.md.
#
# openid: OpenID authentication.
#
# replication: the HTTP replication API (/_synapse/replication). See
- # docs/workers.rst.
+ # docs/workers.md.
#
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
@@ -622,7 +646,7 @@ class ServerConfig(Config):
# that unwraps TLS.
#
# If you plan to use a reverse proxy, please see
- # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
+ # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md.
#
%(unsecure_http_bindings)s
@@ -718,6 +742,19 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#allow_per_room_profiles: false
+
+ # How long to keep redacted events in unredacted form in the database. After
+ # this period redacted events get replaced with their redacted form in the DB.
+ #
+ # Defaults to `7d`. Set to `null` to disable.
+ #
+ #redaction_retention_period: 28d
+
+ # How long to track users' last seen time and IPs in the database.
+ #
+ # Defaults to `28d`. Set to `null` to disable clearing out of old rows.
+ #
+ #user_ips_max_age: 14d
"""
% locals()
)
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index eaac3d73bc..6ea2ea8869 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -59,8 +59,10 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
"""
- def __init__(self):
- super(ServerNoticesConfig, self).__init__()
+ section = "servernotices"
+
+ def __init__(self, *args):
+ super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index e40797ab50..36e0ddab5c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -19,6 +19,8 @@ from ._base import Config
class SpamCheckerConfig(Config):
+ section = "spamchecker"
+
def read_config(self, config, **kwargs):
self.spam_checker = None
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index b18ddbd1fa..62485189ea 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -25,6 +25,8 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
+ section = "stats"
+
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index b3431441b9..10a99c792e 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -19,6 +19,8 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
+ section = "thirdpartyrules"
+
def read_config(self, config, **kwargs):
self.third_party_event_rules = None
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index c0148aa95c..f06341eb67 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,6 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
+from typing import List
import six
@@ -33,7 +34,9 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
- def read_config(self, config, config_dir_path, **kwargs):
+ section = "tls"
+
+ def read_config(self, config: dict, config_dir_path: str, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@@ -57,7 +60,7 @@ class TlsConfig(Config):
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
- if self.has_tls_listener():
+ if self.root.server.has_tls_listener():
if not self.tls_certificate_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
@@ -108,10 +111,17 @@ class TlsConfig(Config):
)
# Support globs (*) in whitelist values
- self.federation_certificate_verification_whitelist = []
+ self.federation_certificate_verification_whitelist = [] # type: List[str]
for entry in fed_whitelist_entries:
+ try:
+ entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))
+ except UnicodeEncodeError:
+ raise ConfigError(
+ "IDNA domain names are not allowed in the "
+ "federation_certificate_verification_whitelist: %s" % (entry,)
+ )
+
# Convert globs to regex
- entry_regex = glob_to_regex(entry)
self.federation_certificate_verification_whitelist.append(entry_regex)
# List of custom certificate authorities for federation traffic validation
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 95e7ccb3a3..8be1346113 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.python_dependencies import DependencyException, check_requirements
+
from ._base import Config, ConfigError
class TracerConfig(Config):
+ section = "tracing"
+
def read_config(self, config, **kwargs):
opentracing_config = config.get("opentracing")
if opentracing_config is None:
@@ -32,6 +36,11 @@ class TracerConfig(Config):
if not self.opentracer_enabled:
return
+ try:
+ check_requirements("opentracing")
+ except DependencyException as e:
+ raise ConfigError(e.message)
+
# The tracer is enabled so sanitize the config
self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", [])
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index f6313e17d4..c8d19c5d6b 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -21,6 +21,8 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
+ section = "userdirectory"
+
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 2ca0e1cf70..a68a3068aa 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -16,6 +16,8 @@ from ._base import Config
class VoipConfig(Config):
+ section = "voip"
+
def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index bc0fc165e3..fef72ed974 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 matrix.org
+# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,8 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
+ section = "worker"
+
def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 06e63a96b5..e93f0b3705 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -15,7 +15,6 @@
import logging
-import idna
from service_identity import VerificationError
from service_identity.pyopenssl import verify_hostname, verify_ip_address
from zope.interface import implementer
@@ -114,14 +113,20 @@ class ClientTLSOptionsFactory(object):
self._no_verify_ssl_context = self._no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(self._context_info_cb)
- def get_options(self, host):
+ def get_options(self, host: bytes):
+
+ # IPolicyForHTTPS.get_options takes bytes, but we want to compare
+ # against the str whitelist. The hostnames in the whitelist are already
+ # IDNA-encoded like the hosts will be here.
+ ascii_host = host.decode("ascii")
+
# Check if certificate verification has been enabled
should_verify = self._config.federation_verify_certificates
# Check if we've disabled certificate verification for this host
if should_verify:
for regex in self._config.federation_certificate_verification_whitelist:
- if regex.match(host):
+ if regex.match(ascii_host):
should_verify = False
break
@@ -162,7 +167,7 @@ class SSLClientConnectionCreator(object):
Replaces twisted.internet.ssl.ClientTLSOptions
"""
- def __init__(self, hostname, ctx, verify_certs):
+ def __init__(self, hostname: bytes, ctx, verify_certs: bool):
self._ctx = ctx
self._verifier = ConnectionVerifier(hostname, verify_certs)
@@ -190,21 +195,16 @@ class ConnectionVerifier(object):
# This code is based on twisted.internet.ssl.ClientTLSOptions.
- def __init__(self, hostname, verify_certs):
+ def __init__(self, hostname: bytes, verify_certs):
self._verify_certs = verify_certs
- if isIPAddress(hostname) or isIPv6Address(hostname):
- self._hostnameBytes = hostname.encode("ascii")
+ _decoded = hostname.decode("ascii")
+ if isIPAddress(_decoded) or isIPv6Address(_decoded):
self._is_ip_address = True
else:
- # twisted's ClientTLSOptions falls back to the stdlib impl here if
- # idna is not installed, but points out that lacks support for
- # IDNA2008 (http://bugs.python.org/issue17305).
- #
- # We can rely on having idna.
- self._hostnameBytes = idna.encode(hostname)
self._is_ip_address = False
+ self._hostnameBytes = hostname
self._hostnameASCII = self._hostnameBytes.decode("ascii")
def verify_context_info_cb(self, ssl_connection, where):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 6ee6216660..5b22a39b7f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -879,44 +879,6 @@ class FederationClient(FederationBase):
)
@defer.inlineCallbacks
- def query_auth(self, destination, room_id, event_id, local_auth):
- """
- Params:
- destination (str)
- event_it (str)
- local_auth (list)
- """
- time_now = self._clock.time_msec()
-
- send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]}
-
- code, content = yield self.transport_layer.send_query_auth(
- destination=destination,
- room_id=room_id,
- event_id=event_id,
- content=send_content,
- )
-
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
-
- auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True, room_version=room_version
- )
-
- signed_auth.sort(key=lambda e: e.depth)
-
- ret = {
- "auth_chain": signed_auth,
- "rejects": content.get("rejects", []),
- "missing": content.get("missing", []),
- }
-
- return ret
-
- @defer.inlineCallbacks
def get_missing_events(
self,
destination,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e5f0b90aec..21e52c9695 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -36,7 +36,6 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
@@ -322,18 +321,6 @@ class FederationServer(FederationBase):
pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
- for event in auth_chain:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event.get_pdu_json(),
- self.hs.hostname,
- self.hs.config.signing_key[0],
- )
- )
-
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
@@ -669,9 +656,9 @@ class FederationServer(FederationBase):
return ret
@defer.inlineCallbacks
- def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ def on_exchange_third_party_invite_request(self, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
- origin, room_id, event_dict
+ room_id, event_dict
)
return ret
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index d46f4aaeb1..2b2ee8612a 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -38,7 +38,7 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.metrics import measure_func
+from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@@ -183,8 +183,8 @@ class FederationSender(object):
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
- destinations = yield self.state.get_current_hosts_in_room(
- event.room_id, latest_event_ids=event.prev_event_ids()
+ destinations = yield self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
@@ -207,8 +207,9 @@ class FederationSender(object):
@defer.inlineCallbacks
def handle_room_events(events):
- for event in events:
- yield handle_event(event)
+ with Measure(self.clock, "handle_room_events"):
+ for event in events:
+ yield handle_event(event)
events_by_room = {}
for event in events:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 482a101c09..7b18408144 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -383,17 +383,6 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_query_auth(self, destination, room_id, event_id, content):
- path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
-
- content = yield self.client.post_json(
- destination=destination, path=path, data=content
- )
-
- return content
-
- @defer.inlineCallbacks
- @log_function
def query_client_keys(self, destination, query_content, timeout):
"""Query the device keys for a list of user ids hosted on a remote
server.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 132a8fb5e6..0f16f21c2d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -165,7 +165,7 @@ class Authenticator(object):
async def _reset_retry_timings(self, origin):
try:
logger.info("Marking origin %r as up", origin)
- await self.store.set_destination_retry_timings(origin, 0, 0)
+ await self.store.set_destination_retry_timings(origin, None, 0, 0)
except Exception:
logger.exception("Error resetting retry timings on %s", origin)
@@ -575,7 +575,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(
- origin, room_id, content
+ room_id, content
)
return 200, content
@@ -765,6 +765,10 @@ class PublicRoomList(BaseFederationServlet):
else:
network_tuple = ThirdPartyInstanceID(None, None)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
data = await maybeDeferred(
self.handler.get_local_public_room_list,
limit,
@@ -800,6 +804,10 @@ class PublicRoomList(BaseFederationServlet):
if search_filter is None:
logger.warning("Nonefilter")
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
data = await self.handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index d50e691436..8f10b6adbb 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,16 +21,16 @@ from six import string_types
from twisted.internet import defer
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
-# TODO: Allow users to "knock" or simpkly join depending on rules
+# TODO: Allow users to "knock" or simply join depending on rules
# TODO: Federation admin APIs
-# TODO: is_priveged flag to users and is_public to users and rooms
+# TODO: is_privileged flag to users and is_public to users and rooms
# TODO: Audit log for admins (profile updates, membership changes, users who tried
# to join but were rejected, etc)
# TODO: Flairs
@@ -590,7 +591,18 @@ class GroupsServerHandler(object):
)
# TODO: Check if user knocked
- # TODO: Check if user is already invited
+
+ invited_users = yield self.store.get_invited_users_in_group(group_id)
+ if user_id in invited_users:
+ raise SynapseError(
+ 400, "User already invited to group", errcode=Codes.BAD_STATE
+ )
+
+ user_results = yield self.store.get_users_in_group(
+ group_id, include_private=True
+ )
+ if user_id in [user_result["user_id"] for user_result in user_results]:
+ raise SynapseError(400, "User already in group")
content = {
"profile": {"name": group["name"], "avatar_url": group["avatar_url"]},
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c29c78bd65..d15c6282fb 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -45,6 +45,7 @@ class BaseHandler(object):
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
+ self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs
@@ -53,7 +54,7 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
- def ratelimit(self, requester, update=True):
+ def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests.
Args:
@@ -62,6 +63,9 @@ class BaseHandler(object):
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
+ is_admin_redaction (bool): Whether this is a room admin/moderator
+ redacting an event. If so then we may apply different
+ ratelimits depending on config.
Raises:
LimitExceededError if the request should be ratelimited
@@ -90,16 +94,33 @@ class BaseHandler(object):
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
- messages_per_second = self.hs.config.rc_message.per_second
- burst_count = self.hs.config.rc_message.burst_count
-
- allowed, time_allowed = self.ratelimiter.can_do_action(
- user_id,
- time_now,
- rate_hz=messages_per_second,
- burst_count=burst_count,
- update=update,
- )
+ # We default to different values if this is an admin redaction and
+ # the config is set
+ if is_admin_redaction and self.hs.config.rc_admin_redaction:
+ messages_per_second = self.hs.config.rc_admin_redaction.per_second
+ burst_count = self.hs.config.rc_admin_redaction.burst_count
+ else:
+ messages_per_second = self.hs.config.rc_message.per_second
+ burst_count = self.hs.config.rc_message.burst_count
+
+ if is_admin_redaction and self.hs.config.rc_admin_redaction:
+ # If we have separate config for admin redactions we use a separate
+ # ratelimiter
+ allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
+ user_id,
+ time_now,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
+ )
+ else:
+ allowed, time_allowed = self.ratelimiter.can_do_action(
+ user_id,
+ time_now,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
+ )
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d0c0142740..333eb30625 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -21,10 +21,8 @@ import unicodedata
import attr
import bcrypt
import pymacaroons
-from canonicaljson import json
from twisted.internet import defer
-from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
@@ -38,7 +36,8 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.types import UserID
@@ -58,13 +57,13 @@ class AuthHandler(BaseHandler):
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs)
- self.checkers = {
- LoginType.RECAPTCHA: self._check_recaptcha,
- LoginType.EMAIL_IDENTITY: self._check_email_identity,
- LoginType.MSISDN: self._check_msisdn,
- LoginType.DUMMY: self._check_dummy_auth,
- LoginType.TERMS: self._check_terms_auth,
- }
+
+ self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
+ for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
+ inst = auth_checker_class(hs)
+ if inst.is_enabled():
+ self.checkers[inst.AUTH_TYPE] = inst
+
self.bcrypt_rounds = hs.config.bcrypt_rounds
# This is not a cache per se, but a store of all current sessions that
@@ -158,6 +157,14 @@ class AuthHandler(BaseHandler):
return params
+ def get_enabled_auth_types(self):
+ """Return the enabled user-interactive authentication types
+
+ Returns the UI-Auth types which are supported by the homeserver's current
+ config.
+ """
+ return self.checkers.keys()
+
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
@@ -292,7 +299,7 @@ class AuthHandler(BaseHandler):
sess["creds"] = {}
creds = sess["creds"]
- result = yield self.checkers[stagetype](authdict, clientip)
+ result = yield self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
@@ -363,7 +370,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
- res = yield checker(authdict, clientip=clientip)
+ res = yield checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
@@ -376,116 +383,6 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
- @defer.inlineCallbacks
- def _check_recaptcha(self, authdict, clientip, **kwargs):
- try:
- user_response = authdict["response"]
- except KeyError:
- # Client tried to provide captcha but didn't give the parameter:
- # bad request.
- raise LoginError(
- 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
- )
-
- logger.info(
- "Submitting recaptcha response %s with remoteip %s", user_response, clientip
- )
-
- # TODO: get this from the homeserver rather than creating a new one for
- # each request
- try:
- client = self.hs.get_simple_http_client()
- resp_body = yield client.post_urlencoded_get_json(
- self.hs.config.recaptcha_siteverify_api,
- args={
- "secret": self.hs.config.recaptcha_private_key,
- "response": user_response,
- "remoteip": clientip,
- },
- )
- except PartialDownloadError as pde:
- # Twisted is silly
- data = pde.response
- resp_body = json.loads(data)
-
- if "success" in resp_body:
- # Note that we do NOT check the hostname here: we explicitly
- # intend the CAPTCHA to be presented by whatever client the
- # user is using, we just care that they have completed a CAPTCHA.
- logger.info(
- "%s reCAPTCHA from hostname %s",
- "Successful" if resp_body["success"] else "Failed",
- resp_body.get("hostname"),
- )
- if resp_body["success"]:
- return True
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
- def _check_email_identity(self, authdict, **kwargs):
- return self._check_threepid("email", authdict, **kwargs)
-
- def _check_msisdn(self, authdict, **kwargs):
- return self._check_threepid("msisdn", authdict)
-
- def _check_dummy_auth(self, authdict, **kwargs):
- return defer.succeed(True)
-
- def _check_terms_auth(self, authdict, **kwargs):
- return defer.succeed(True)
-
- @defer.inlineCallbacks
- def _check_threepid(self, medium, authdict, **kwargs):
- if "threepid_creds" not in authdict:
- raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
-
- threepid_creds = authdict["threepid_creds"]
-
- identity_handler = self.hs.get_handlers().identity_handler
-
- logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- threepid = yield identity_handler.threepid_from_creds(threepid_creds)
- elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- row = yield self.store.get_threepid_validation_session(
- medium,
- threepid_creds["client_secret"],
- sid=threepid_creds["sid"],
- validated=True,
- )
-
- threepid = (
- {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
- }
- if row
- else None
- )
-
- if row:
- # Valid threepid returned, delete from the db
- yield self.store.delete_threepid_session(threepid_creds["sid"])
- else:
- raise SynapseError(
- 400, "Password resets are not enabled on this homeserver"
- )
-
- if not threepid:
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
- if threepid["medium"] != medium:
- raise LoginError(
- 401,
- "Expecting threepid of type '%s', got '%s'"
- % (medium, threepid["medium"]),
- errcode=Codes.UNAUTHORIZED,
- )
-
- threepid["threepid_creds"] = authdict["threepid_creds"]
-
- return threepid
-
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 5f804d1f13..63267a0a4c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -73,7 +73,9 @@ class DeactivateAccountHandler(BaseHandler):
# unbinding
identity_server_supports_unbinding = True
- threepids = yield self.store.user_get_threepids(user_id)
+ # Retrieve the 3PIDs this user has bound to an identity server
+ threepids = yield self.store.user_get_bound_threepids(user_id)
+
for threepid in threepids:
try:
result = yield self._identity_handler.try_unbind_threepid(
@@ -118,6 +120,10 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+ # Reject all pending invites for the user, so that the user doesn't show up in the
+ # "invited" section of rooms' members list.
+ yield self._reject_pending_invites_for_user(user_id)
+
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id)
@@ -127,6 +133,39 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding
+ @defer.inlineCallbacks
+ def _reject_pending_invites_for_user(self, user_id):
+ """Reject pending invites addressed to a given user ID.
+
+ Args:
+ user_id (str): The user ID to reject pending invites for.
+ """
+ user = UserID.from_string(user_id)
+ pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+
+ for room in pending_invites:
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room.room_id,
+ "leave",
+ ratelimit=False,
+ require_consent=False,
+ )
+ logger.info(
+ "Rejected invite for deactivated user %r in room %r",
+ user_id,
+ room.room_id,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to reject invite for user %r in room %r:"
+ " ignoring and continuing",
+ user_id,
+ room.room_id,
+ )
+
def _start_user_parting(self):
"""
Start the process that goes through the table of users
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 6bf3ef49a8..5ea54f60be 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -329,16 +329,10 @@ class E2eKeysHandler(object):
results = yield self.store.get_e2e_device_keys(local_query)
- # Build the result structure, un-jsonify the results, and add the
- # "unsigned" section
+ # Build the result structure
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
- r = dict(device_info["keys"])
- r["unsigned"] = {}
- display_name = device_info["device_display_name"]
- if display_name is not None:
- r["unsigned"]["device_display_name"] = display_name
- result_dict[user_id][device_id] = r
+ result_dict[user_id][device_id] = device_info
log_kv(results)
return result_dict
@@ -750,7 +744,7 @@ class E2eKeysHandler(object):
)
try:
- stored_device = devices[device_id]["keys"]
+ stored_device = devices[device_id]
except KeyError:
raise NotFoundError("Unknown device")
if self_signing_key_id in stored_device.get("signatures", {}).get(
@@ -800,14 +794,14 @@ class E2eKeysHandler(object):
_, signing_device_id = signing_key_id.split(":", 1)
if (
signing_device_id not in devices
- or signing_key_id not in devices[signing_device_id]["keys"]["keys"]
+ or signing_key_id not in devices[signing_device_id]["keys"]
):
# signed by an unknown device, or the
# device does not have the key
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
# get the key and check the signature
- pubkey = devices[signing_device_id]["keys"]["keys"][signing_key_id]
+ pubkey = devices[signing_device_id]["keys"][signing_key_id]
verify_key = decode_verify_key_bytes(signing_key_id, decode_base64(pubkey))
_check_device_signature(
user_id, verify_key, signed_master_key, stored_master_key
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index a9d80f708c..0cea445f0d 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -352,8 +352,8 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict.
"""
if "version" not in version_info:
- raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM)
- if version_info["version"] != version:
+ version_info["version"] = version
+ elif version_info["version"] != version:
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 538b16efd6..57f661f16e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2181,103 +2181,10 @@ class FederationHandler(BaseHandler):
auth_events.update(new_state)
- different_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
yield self._update_context_for_auth_events(
event, context, auth_events, event_key
)
- if not different_auth:
- # we're done
- return
-
- logger.info(
- "auth_events still refers to events which are not in the calculated auth "
- "chain after state resolution: %s",
- different_auth,
- )
-
- # Only do auth resolution if we have something new to say.
- # We can't prove an auth failure.
- do_resolution = False
-
- for e_id in different_auth:
- if e_id in have_events:
- if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
- do_resolution = True
- break
-
- if not do_resolution:
- logger.info(
- "Skipping auth resolution due to lack of provable rejection reasons"
- )
- return
-
- logger.info("Doing auth resolution")
-
- prev_state_ids = yield context.get_prev_state_ids(self.store)
-
- # 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids)
- local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True)
-
- try:
- # 2. Get remote difference.
- try:
- result = yield self.federation_client.query_auth(
- origin, event.room_id, event.event_id, local_auth_chain
- )
- except RequestSendFailed as e:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to query auth from remote: %s", e)
- return
-
- seen_remotes = yield self.store.have_seen_events(
- [e.event_id for e in result["auth_chain"]]
- )
-
- # 3. Process any remote auth chain events we haven't seen.
- for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes:
- continue
-
- if ev.event_id == event.event_id:
- continue
-
- try:
- auth_ids = ev.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in result["auth_chain"]
- if e.event_id in auth_ids or event.type == EventTypes.Create
- }
- ev.internal_metadata.outlier = True
-
- logger.debug(
- "do_auth %s different_auth: %s", event.event_id, e.event_id
- )
-
- yield self._handle_new_event(origin, ev, auth_events=auth)
-
- if ev.event_id in event_auth_events:
- auth_events[(ev.type, ev.state_key)] = ev
- except AuthError:
- pass
-
- except Exception:
- # FIXME:
- logger.exception("Failed to query auth chain")
-
- # 4. Look at rejects and their proofs.
- # TODO.
-
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key
- )
-
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
@@ -2444,15 +2351,6 @@ class FederationHandler(BaseHandler):
reason_map[e.event_id] = reason
- if reason == RejectedReason.AUTH_ERROR:
- pass
- elif reason == RejectedReason.REPLACED:
- # TODO: Get proof
- pass
- elif reason == RejectedReason.NOT_ANCESTOR:
- # TODO: Get proof.
- pass
-
logger.debug("construct_auth_difference returning")
return {
@@ -2530,12 +2428,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ def on_exchange_third_party_invite_request(self, room_id, event_dict):
"""Handle an exchange_third_party_invite request from a remote server
The remote server will call this when it wants to turn a 3pid invite
into a normal m.room.member invite.
+ Args:
+ room_id (str): The ID of the room.
+
+ event_dict (dict[str, Any]): Dictionary containing the event body.
+
Returns:
Deferred: resolves (to None)
"""
@@ -2565,7 +2468,7 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check_from_context(room_version, event, context)
+ yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
@@ -2594,7 +2497,12 @@ class FederationHandler(BaseHandler):
original_invite_id, allow_none=True
)
if original_invite:
- display_name = original_invite.content["display_name"]
+ # If the m.room.third_party_invite event's content is empty, it means the
+ # invite has been revoked. In this case, we don't have to raise an error here
+ # because the auth check will fail on the invite (because it's not able to
+ # fetch public keys from the m.room.third_party_invite event's content, which
+ # is empty).
+ display_name = original_invite.content.get("display_name")
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 45db1c1c06..ba99ddf76d 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -18,160 +18,150 @@
"""Utilities for interacting with Identity Servers"""
import logging
+import urllib
from canonicaljson import json
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
+from unpaddedbase64 import decode_base64
from twisted.internet import defer
+from twisted.internet.error import TimeoutError
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
SynapseError,
)
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.client import SimpleHttpClient
+from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import random_string
from ._base import BaseHandler
logger = logging.getLogger(__name__)
+id_server_scheme = "https://"
+
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
- self.http_client = hs.get_simple_http_client()
+ self.http_client = SimpleHttpClient(hs)
+ # We create a blacklisting instance of SimpleHttpClient for contacting identity
+ # servers specified by clients
+ self.blacklisting_http_client = SimpleHttpClient(
+ hs, ip_blacklist=hs.config.federation_ip_range_blacklist
+ )
self.federation_http_client = hs.get_http_client()
self.hs = hs
- def _extract_items_from_creds_dict(self, creds):
+ @defer.inlineCallbacks
+ def threepid_from_creds(self, id_server, creds):
"""
- Retrieve entries from a "credentials" dictionary
+ Retrieve and validate a threepid identifier from a "credentials" dictionary against a
+ given identity server
Args:
- creds (dict[str, str]): Dictionary of credentials that contain the following keys:
+ id_server (str): The identity server to validate 3PIDs against. Must be a
+ complete URL including the protocol (http(s)://)
+
+ creds (dict[str, str]): Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
- * id_server|idServer: the domain of the identity server to query
- * id_access_token: The access token to authenticate to the identity
- server with.
+ * sid: The ID of the validation session
Returns:
- tuple(str, str, str|None): A tuple containing the client_secret, the id_server,
- and the id_access_token value if available.
+ Deferred[dict[str,str|int]|None]: A dictionary of response parameters from
+ the /getValidated3pid endpoint of the Identity Service API, or None if the
+ threepid was not found
"""
client_secret = creds.get("client_secret") or creds.get("clientSecret")
if not client_secret:
raise SynapseError(
- 400, "No client_secret in creds", errcode=Codes.MISSING_PARAM
+ 400, "Missing param client_secret in creds", errcode=Codes.MISSING_PARAM
)
-
- id_server = creds.get("id_server") or creds.get("idServer")
- if not id_server:
+ session_id = creds.get("sid")
+ if not session_id:
raise SynapseError(
- 400, "No id_server in creds", errcode=Codes.MISSING_PARAM
+ 400, "Missing param session_id in creds", errcode=Codes.MISSING_PARAM
)
- id_access_token = creds.get("id_access_token")
- return client_secret, id_server, id_access_token
+ query_params = {"sid": session_id, "client_secret": client_secret}
- @defer.inlineCallbacks
- def threepid_from_creds(self, creds, use_v2=True):
- """
- Retrieve and validate a threepid identitier from a "credentials" dictionary
-
- Args:
- creds (dict[str, str]): Dictionary of credentials that contain the following keys:
- * client_secret|clientSecret: A unique secret str provided by the client
- * id_server|idServer: the domain of the identity server to query
- * id_access_token: The access token to authenticate to the identity
- server with. Required if use_v2 is true
- use_v2 (bool): Whether to use v2 Identity Service API endpoints
-
- Returns:
- Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
- the /getValidated3pid endpoint of the Identity Service API, or None if the
- threepid was not found
- """
- client_secret, id_server, id_access_token = self._extract_items_from_creds_dict(
- creds
- )
-
- # If an id_access_token is not supplied, force usage of v1
- if id_access_token is None:
- use_v2 = False
-
- query_params = {"sid": creds["sid"], "client_secret": client_secret}
-
- # Decide which API endpoint URLs and query parameters to use
- if use_v2:
- url = "https://%s%s" % (
- id_server,
- "/_matrix/identity/v2/3pid/getValidated3pid",
- )
- query_params["id_access_token"] = id_access_token
- else:
- url = "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/3pid/getValidated3pid",
- )
+ url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
try:
data = yield self.http_client.get_json(url, query_params)
- return data if "medium" in data else None
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
- if e.code != 404 or not use_v2:
- # Generic failure
- logger.info("getValidated3pid failed with Matrix error: %r", e)
- raise e.to_synapse_error()
+ logger.info(
+ "%s returned %i for threepid validation for: %s",
+ id_server,
+ e.code,
+ creds,
+ )
+ return None
+
+ # Old versions of Sydent return a 200 http code even on a failed validation
+ # check. Thus, in addition to the HttpResponseException check above (which
+ # checks for non-200 errors), we need to make sure validation_session isn't
+ # actually an error, identified by the absence of a "medium" key
+ # See https://github.com/matrix-org/sydent/issues/215 for details
+ if "medium" in data:
+ return data
- # This identity server is too old to understand Identity Service API v2
- # Attempt v1 endpoint
- logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", url)
- return (yield self.threepid_from_creds(creds, use_v2=False))
+ logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ return None
@defer.inlineCallbacks
- def bind_threepid(self, creds, mxid, use_v2=True):
+ def bind_threepid(
+ self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
+ ):
"""Bind a 3PID to an identity server
Args:
- creds (dict[str, str]): Dictionary of credentials that contain the following keys:
- * client_secret|clientSecret: A unique secret str provided by the client
- * id_server|idServer: the domain of the identity server to query
- * id_access_token: The access token to authenticate to the identity
- server with. Required if use_v2 is true
+ client_secret (str): A unique secret provided by the client
+
+ sid (str): The ID of the validation session
+
mxid (str): The MXID to bind the 3PID to
- use_v2 (bool): Whether to use v2 Identity Service API endpoints
+
+ id_server (str): The domain of the identity server to query
+
+ id_access_token (str): The access token to authenticate to the identity
+ server with, if necessary. Required if use_v2 is true
+
+ use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True
Returns:
Deferred[dict]: The response from the identity server
"""
- logger.debug("binding threepid %r to %s", creds, mxid)
-
- client_secret, id_server, id_access_token = self._extract_items_from_creds_dict(
- creds
- )
-
- sid = creds.get("sid")
- if not sid:
- raise SynapseError(
- 400, "No sid in three_pid_creds", errcode=Codes.MISSING_PARAM
- )
+ logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
# If an id_access_token is not supplied, force usage of v1
if id_access_token is None:
use_v2 = False
# Decide which API endpoint URLs to use
+ headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
- bind_data["id_access_token"] = id_access_token
+ headers["Authorization"] = create_id_access_token_header(id_access_token)
else:
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
try:
- data = yield self.http_client.post_json_get_json(bind_url, bind_data)
- logger.debug("bound threepid %r to %s", creds, mxid)
+ # Use the blacklisting http client as this call is only to identity servers
+ # provided by a client
+ data = yield self.blacklisting_http_client.post_json_get_json(
+ bind_url, bind_data, headers=headers
+ )
# Remember where we bound the threepid
yield self.store.add_user_bound_threepid(
@@ -186,12 +176,17 @@ class IdentityHandler(BaseHandler):
if e.code != 404 or not use_v2:
logger.error("3PID bind failed with Matrix error: %r", e)
raise e.to_synapse_error()
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
- return (yield self.bind_threepid(creds, mxid, use_v2=False))
+ res = yield self.bind_threepid(
+ client_secret, sid, mxid, id_server, id_access_token, use_v2=False
+ )
+ return res
@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
@@ -267,7 +262,11 @@ class IdentityHandler(BaseHandler):
headers = {b"Authorization": auth_headers}
try:
- yield self.http_client.post_json_get_json(url, content, headers)
+ # Use the blacklisting http client as this call is only to identity servers
+ # provided by a client
+ yield self.blacklisting_http_client.post_json_get_json(
+ url, content, headers
+ )
changed = True
except HttpResponseException as e:
changed = False
@@ -276,7 +275,9 @@ class IdentityHandler(BaseHandler):
logger.warn("Received %d response while unbinding threepid", e.code)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
- raise SynapseError(502, "Failed to contact identity server")
+ raise SynapseError(500, "Failed to contact identity server")
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
yield self.store.remove_user_bound_threepid(
user_id=mxid,
@@ -335,6 +336,15 @@ class IdentityHandler(BaseHandler):
# Generate a session id
session_id = random_string(16)
+ if next_link:
+ # Manipulate the next_link to add the sid, because the caller won't get
+ # it until we send a response, by which time we've sent the mail.
+ if "?" in next_link:
+ next_link += "&"
+ else:
+ next_link += "?"
+ next_link += "sid=" + urllib.parse.quote(session_id)
+
# Generate a new validation token
token = random_string(32)
@@ -409,6 +419,8 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
@defer.inlineCallbacks
def requestMsisdnToken(
@@ -457,7 +469,476 @@ class IdentityHandler(BaseHandler):
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params,
)
- return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+
+ assert self.hs.config.public_baseurl
+
+ # we need to tell the client to send the token back to us, since it doesn't
+ # otherwise know where to send it, so add submit_url response parameter
+ # (see also MSC2078)
+ data["submit_url"] = (
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/add_threepid/msisdn/submit_token"
+ )
+ return data
+
+ @defer.inlineCallbacks
+ def validate_threepid_session(self, client_secret, sid):
+ """Validates a threepid session with only the client secret and session ID
+ Tries validating against any configured account_threepid_delegates as well as locally.
+
+ Args:
+ client_secret (str): A secret provided by the client
+
+ sid (str): The ID of the session
+
+ Returns:
+ Deferred[dict[str, str|int]|None]: the validated session dict if validation was successful, otherwise None
+ """
+ # XXX: We shouldn't need to keep wrapping and unwrapping this value
+ threepid_creds = {"client_secret": client_secret, "sid": sid}
+
+ # We don't actually know which medium this 3PID is. Thus we first assume it's email,
+ # and if validation fails we try msisdn
+ validation_session = None
+
+ # Try to validate as email
+ if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ # Ask our delegated email identity server
+ validation_session = yield self.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_email, threepid_creds
+ )
+ elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ # Get a validated session matching these details
+ validation_session = yield self.store.get_threepid_validation_session(
+ "email", client_secret, sid=sid, validated=True
+ )
+
+ if validation_session:
+ return validation_session
+
+ # Try to validate as msisdn
+ if self.hs.config.account_threepid_delegate_msisdn:
+ # Ask our delegated msisdn identity server
+ validation_session = yield self.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ )
+
+ return validation_session
+
+ @defer.inlineCallbacks
+ def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ """Proxy a POST submitToken request to an identity server for verification purposes
+
+ Args:
+ id_server (str): The identity server URL to contact
+
+ client_secret (str): Secret provided by the client
+
+ sid (str): The ID of the session
+
+ token (str): The verification token
+
+ Raises:
+ SynapseError: If we failed to contact the identity server
+
+ Returns:
+ Deferred[dict]: The response dict from the identity server
+ """
+ body = {"client_secret": client_secret, "sid": sid, "token": token}
+
+ try:
+ return (
+ yield self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+ body,
+ )
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
+ raise SynapseError(400, "Error contacting the identity server")
+
+ @defer.inlineCallbacks
+ def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+ id_access_token (str|None): The access token to authenticate to the identity
+ server with
+
+ Returns:
+ Deferred[str|None]: the Matrix ID of the 3pid, or None if it is not recognized.
+ """
+ if id_access_token is not None:
+ try:
+ results = yield self._lookup_3pid_v2(
+ id_server, id_access_token, medium, address
+ )
+ return results
+
+ except Exception as e:
+ # Catch HttpResponseException for a non-200 response code
+ # Check if this identity server does not know about v2 lookups
+ if isinstance(e, HttpResponseException) and e.code == 404:
+ # This is an old identity server that does not yet support v2 lookups
+ logger.warning(
+ "Attempted v2 lookup on v1 identity server %s. Falling "
+ "back to v1",
+ id_server,
+ )
+ else:
+ logger.warning("Error when looking up hashing details: %s", e)
+ return None
+
+ return (yield self._lookup_3pid_v1(id_server, medium, address))
+
+ @defer.inlineCallbacks
+ def _lookup_3pid_v1(self, id_server, medium, address):
+ """Looks up a 3pid in the passed identity server using v1 lookup.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[str|None]: the Matrix ID of the 3pid, or None if it is not recognized.
+ """
+ try:
+ data = yield self.blacklisting_http_client.get_json(
+ "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+ return data["mxid"]
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except IOError as e:
+ logger.warning("Error from v1 identity server lookup: %s" % (e,))
+
+ return None
+
+ @defer.inlineCallbacks
+ def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ """Looks up a 3pid in the passed identity server using v2 lookup.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ id_access_token (str): The access token to authenticate to the identity server with
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
+ """
+ # Check what hashing details are supported by this identity server
+ try:
+ hash_details = yield self.blacklisting_http_client.get_json(
+ "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ {"access_token": id_access_token},
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+
+ if not isinstance(hash_details, dict):
+ logger.warning(
+ "Got non-dict object when checking hash details of %s%s: %s",
+ id_server_scheme,
+ id_server,
+ hash_details,
+ )
+ raise SynapseError(
+ 400,
+ "Non-dict object from %s%s during v2 hash_details request: %s"
+ % (id_server_scheme, id_server, hash_details),
+ )
+
+ # Extract information from hash_details
+ supported_lookup_algorithms = hash_details.get("algorithms")
+ lookup_pepper = hash_details.get("lookup_pepper")
+ if (
+ not supported_lookup_algorithms
+ or not isinstance(supported_lookup_algorithms, list)
+ or not lookup_pepper
+ or not isinstance(lookup_pepper, str)
+ ):
+ raise SynapseError(
+ 400,
+ "Invalid hash details received from identity server %s%s: %s"
+ % (id_server_scheme, id_server, hash_details),
+ )
+
+ # Check if any of the supported lookup algorithms are present
+ if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
+ # Perform a hashed lookup
+ lookup_algorithm = LookupAlgorithm.SHA256
+
+ # Hash address, medium and the pepper with sha256
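+ # (illustrative: for address "foo@example.com", medium "email" and pepper
+ # "matrixrocks", we hash the string "foo@example.com email matrixrocks")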
+ to_hash = "%s %s %s" % (address, medium, lookup_pepper)
+ lookup_value = sha256_and_url_safe_base64(to_hash)
+
+ elif LookupAlgorithm.NONE in supported_lookup_algorithms:
+ # Perform a non-hashed lookup
+ lookup_algorithm = LookupAlgorithm.NONE
+
+ # Combine together plaintext address and medium
+ lookup_value = "%s %s" % (address, medium)
+
+ else:
+ logger.warning(
+ "None of the provided lookup algorithms of %s are supported: %s",
+ id_server,
+ supported_lookup_algorithms,
+ )
+ raise SynapseError(
+ 400,
+ "Provided identity server does not support any v2 lookup "
+ "algorithms that this homeserver supports.",
+ )
+
+ # Authenticate with identity server given the access token from the client
+ headers = {"Authorization": create_id_access_token_header(id_access_token)}
+
+ try:
+ lookup_results = yield self.blacklisting_http_client.post_json_get_json(
+ "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ {
+ "addresses": [lookup_value],
+ "algorithm": lookup_algorithm,
+ "pepper": lookup_pepper,
+ },
+ headers=headers,
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except Exception as e:
+ logger.warning("Error when performing a v2 3pid lookup: %s", e)
+ raise SynapseError(
+ 500, "Unknown error occurred during identity server lookup"
+ )
+
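+ # A successful response contains a "mappings" dict, for instance
+ # (illustrative): {"mappings": {"<lookup_value>": "@alice:example.com"}}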
+ # Check for a mapping from what we looked up to an MXID
+ if "mappings" not in lookup_results or not isinstance(
+ lookup_results["mappings"], dict
+ ):
+ logger.warning("No results from 3pid lookup")
+ return None
+
+ # Return the MXID if it's available, or None otherwise
+ mxid = lookup_results["mappings"].get(lookup_value)
+ return mxid
+
+ @defer.inlineCallbacks
+ def _verify_any_signature(self, data, server_hostname):
+ if server_hostname not in data["signatures"]:
+ raise AuthError(401, "No signature from server %s" % (server_hostname,))
+ for key_name, signature in data["signatures"][server_hostname].items():
+ try:
+ key_data = yield self.blacklisting_http_client.get_json(
+ "%s%s/_matrix/identity/api/v1/pubkey/%s"
+ % (id_server_scheme, server_hostname, key_name)
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ if "public_key" not in key_data:
+ raise AuthError(
+ 401, "No public key named %s from %s" % (key_name, server_hostname)
+ )
+ verify_signed_json(
+ data,
+ server_hostname,
+ decode_verify_key_bytes(
+ key_name, decode_base64(key_data["public_key"])
+ ),
+ )
+ return
+
+ @defer.inlineCallbacks
+ def ask_id_server_for_third_party_invite(
+ self,
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ inviter_user_id,
+ room_alias,
+ room_avatar_url,
+ room_join_rules,
+ room_name,
+ inviter_display_name,
+ inviter_avatar_url,
+ id_access_token=None,
+ ):
+ """
+ Asks an identity server for a third party invite.
+
+ Args:
+ requester (Requester)
+ id_server (str): hostname + optional port for the identity server.
+ medium (str): The literal string "email".
+ address (str): The third party address being invited.
+ room_id (str): The ID of the room to which the user is invited.
+ inviter_user_id (str): The user ID of the inviter.
+ room_alias (str): An alias for the room, for cosmetic notifications.
+ room_avatar_url (str): The URL of the room's avatar, for cosmetic
+ notifications.
+ room_join_rules (str): The join rules of the room (e.g. "public").
+ room_name (str): The m.room.name of the room.
+ inviter_display_name (str): The current display name of the
+ inviter.
+ inviter_avatar_url (str): The URL of the inviter's avatar.
+ id_access_token (str|None): The access token to authenticate to the identity
+ server with
+
+ Returns:
+ A deferred tuple containing:
+ token (str): The token which must be signed to prove authenticity.
+ public_keys ([{"public_key": str, "key_validity_url": str}]):
+ public_key is a base64-encoded ed25519 public key.
+ fallback_public_key: One element from public_keys.
+ display_name (str): A user-friendly name to represent the invited
+ user.
+ """
+ invite_config = {
+ "medium": medium,
+ "address": address,
+ "room_id": room_id,
+ "room_alias": room_alias,
+ "room_avatar_url": room_avatar_url,
+ "room_join_rules": room_join_rules,
+ "room_name": room_name,
+ "sender": inviter_user_id,
+ "sender_display_name": inviter_display_name,
+ "sender_avatar_url": inviter_avatar_url,
+ }
+
+ # Add the identity service access token to the JSON body and use the v2
+ # Identity Service endpoints if id_access_token is present
+ data = None
+ base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+
+ if id_access_token:
+ key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_scheme,
+ id_server,
+ )
+
+ # Attempt a v2 lookup
+ url = base_url + "/v2/store-invite"
+ try:
+ data = yield self.blacklisting_http_client.post_json_get_json(
+ url,
+ invite_config,
+ {"Authorization": create_id_access_token_header(id_access_token)},
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ if e.code != 404:
+ logger.info("Failed to POST %s with JSON: %s", url, e)
+ raise e
+
+ if data is None:
+ key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_scheme,
+ id_server,
+ )
+ url = base_url + "/api/v1/store-invite"
+
+ try:
+ data = yield self.blacklisting_http_client.post_json_get_json(
+ url, invite_config
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ logger.warning(
+ "Error trying to call /store-invite on %s%s: %s",
+ id_server_scheme,
+ id_server,
+ e,
+ )
+
+ if data is None:
+ # Some identity servers may only support application/x-www-form-urlencoded
+ # types. This is especially true with old instances of Sydent, see
+ # https://github.com/matrix-org/sydent/pull/170
+ try:
+ data = yield self.blacklisting_http_client.post_urlencoded_get_json(
+ url, invite_config
+ )
+ except HttpResponseException as e:
+ logger.warning(
+ "Error calling /store-invite on %s%s with fallback "
+ "encoding: %s",
+ id_server_scheme,
+ id_server,
+ e,
+ )
+ raise e
+
+ # TODO: Check for success
+ token = data["token"]
+ public_keys = data.get("public_keys", [])
+ if "public_key" in data:
+ fallback_public_key = {
+ "public_key": data["public_key"],
+ "key_validity_url": key_validity_url,
+ }
+ else:
+ fallback_public_key = public_keys[0]
+
+ if not public_keys:
+ public_keys.append(fallback_public_key)
+ display_name = data["display_name"]
+ return token, public_keys, fallback_public_key, display_name
+
+
+def create_id_access_token_header(id_access_token):
+ """Create an Authorization header for passing to SimpleHttpClient as the header value
+ of an HTTP request.
+
+ Args:
+ id_access_token (str): An identity server access token.
+
+ Returns:
+ list[str]: The bearer token encased in a list.
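+
+ For example (illustrative token value):
+
+ create_id_access_token_header("abc123") == ["Bearer abc123"]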
+ """
+ # Prefix with Bearer
+ bearer_token = "Bearer %s" % id_access_token
+
+ # Check that the token is encodable as standard ascii; this raises if it
+ # is not. Note that the encoded result is discarded: the value returned
+ # below is still a str.
+ bearer_token.encode("ascii")
+
+ # Return as a list as that's how SimpleHttpClient takes header values
+ return [bearer_token]
+
+
+class LookupAlgorithm:
+ """
+ Supported hashing algorithms when performing a 3PID lookup.
+
+ SHA256 - Hashing an (address, medium, pepper) combo with sha256, then url-safe base64
+ encoding
+ NONE - Not performing any hashing. Simply sending an (address, medium) combo in plaintext
+ """
+
+ SHA256 = "sha256"
+ NONE = "none"
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 111f7c7e2f..0f8cce8ffe 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -222,6 +222,13 @@ class MessageHandler(object):
}
+# The duration (in ms) after which rooms should be removed from
+# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will
+# try to generate a dummy event for them once more). Currently one week.
+#
+_DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
+
+
class EventCreationHandler(object):
def __init__(self, hs):
self.hs = hs
@@ -258,6 +265,13 @@ class EventCreationHandler(object):
self.config.block_events_without_consent_error
)
+ # Rooms which should be excluded from dummy event insertion. (For instance,
+ # those without local users who can send events into the room.)
+ #
+ # map from room id to time-of-last-attempt.
+ #
+ self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int]
+
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
# config options, but *only* if we have a configuration for which we are
# going to need it.
@@ -729,7 +743,27 @@ class EventCreationHandler(object):
assert not self.config.worker_app
if ratelimit:
- yield self.base_handler.ratelimit(requester)
+ # We check if this is a room admin redacting an event so that we
+ # can apply different ratelimiting. We do this by simply checking
+ # it's not a self-redaction (to avoid having to look up whether the
+ # user is actually admin or not).
+ is_admin_redaction = False
+ if event.type == EventTypes.Redaction:
+ original_event = yield self.store.get_event(
+ event.redacts,
+ check_redacted=False,
+ get_prev_content=False,
+ allow_rejected=False,
+ allow_none=True,
+ )
+
+ is_admin_redaction = (
+ original_event and event.sender != original_event.sender
+ )
+
+ yield self.base_handler.ratelimit(
+ requester, is_admin_redaction=is_admin_redaction
+ )
yield self.base_handler.maybe_kick_guest_users(event, context)
@@ -868,9 +902,11 @@ class EventCreationHandler(object):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
-
+ self._expire_rooms_to_exclude_from_dummy_event_insertion()
room_ids = yield self.store.get_rooms_with_many_extremities(
- min_count=10, limit=5
+ min_count=10,
+ limit=5,
+ room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
)
for room_id in room_ids:
@@ -884,32 +920,61 @@ class EventCreationHandler(object):
members = yield self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
+ dummy_event_sent = False
+ for user_id in members:
+ if not self.hs.is_mine_id(user_id):
+ continue
+ requester = create_requester(user_id)
+ try:
+ event, context = yield self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
- user_id = None
- for member in members:
- if self.hs.is_mine_id(member):
- user_id = member
- break
-
- if not user_id:
- # We don't have a joined user.
- # TODO: We should do something here to stop the room from
- # appearing next time.
- continue
+ event.internal_metadata.proactively_send = False
- requester = create_requester(user_id)
+ yield self.send_nonmember_event(
+ requester, event, context, ratelimit=False
+ )
+ dummy_event_sent = True
+ break
+ except ConsentNotGivenError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of consent. Will try another user",
+ room_id,
+ user_id,
+ )
+ except AuthError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of power. Will try another user",
+ room_id,
+ user_id,
+ )
- event, context = yield self.create_event(
- requester,
- {
- "type": "org.matrix.dummy_event",
- "content": {},
- "room_id": room_id,
- "sender": user_id,
- },
- prev_events_and_hashes=prev_events_and_hashes,
+ if not dummy_event_sent:
+ # We did not find a valid user in the room, so exclude it from future
+ # attempts. The exclusion is time-limited, so the room will be
+ # rechecked once _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY has passed.
+ logger.info(
+ "Failed to send dummy event into room %s. Will exclude it from "
+ "future attempts until cache expires",
+ room_id,
+ )
+ now = self.clock.time_msec()
+ self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
+
+ def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
+ expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
+ to_expire = set()
+ for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():
+ if time < expire_before:
+ to_expire.add(room_id)
+ for room_id in to_expire:
+ logger.debug(
+ "Expiring room id %s from dummy event insertion exclusion cache",
+ room_id,
)
-
- event.internal_metadata.proactively_send = False
-
- yield self.send_nonmember_event(requester, event, context, ratelimit=False)
+ del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 053cf66b28..eda15bc623 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,6 +24,7 @@ The methods that define policy are:
import logging
from contextlib import contextmanager
+from typing import Dict, Set
from six import iteritems, itervalues
@@ -179,8 +180,9 @@ class PresenceHandler(object):
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
- self.external_process_to_current_syncs = {}
- self.external_process_last_updated_ms = {}
+ self.external_process_to_current_syncs = {} # type: Dict[int, Set[str]]
+ self.external_process_last_updated_ms = {} # type: Dict[int, int]
+
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
# Start a LoopingCall in 30s that fires every 5s.
@@ -349,10 +351,13 @@ class PresenceHandler(object):
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
+ # For each expired process drop tracking info and check the users
+ # that were syncing on that process to see if they need to be timed
+ # out.
users_to_check.update(
- self.external_process_last_updated_ms.pop(process_id, ())
+ self.external_process_to_current_syncs.pop(process_id, ())
)
- self.external_process_last_update.pop(process_id)
+ self.external_process_last_updated_ms.pop(process_id)
states = [
self.user_to_current_state.get(user_id, UserPresenceState.default(user_id))
@@ -803,17 +808,25 @@ class PresenceHandler(object):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
- deltas = yield self.store.get_current_state_deltas(self._event_pos)
- if not deltas:
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self._event_pos == room_max_stream_ordering:
return
+ logger.debug(
+ "Processing presence stats %s->%s",
+ self._event_pos,
+ room_max_stream_ordering,
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self._event_pos, room_max_stream_ordering
+ )
yield self._handle_state_delta(deltas)
- self._event_pos = deltas[-1]["stream_id"]
+ self._event_pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set(
- self._event_pos
+ max_pos
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 975da57ffd..53410f120b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -217,10 +217,9 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
- attempts = 0
user = None
while not user:
- localpart = yield self._generate_user_id(attempts > 0)
+ localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
@@ -238,7 +237,6 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
user = None
user_id = None
- attempts += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
@@ -275,16 +273,12 @@ class RegistrationHandler(BaseHandler):
fake_requester = create_requester(user_id)
# try to create the room if we're the first real user on the server. Note
- # that an auto-generated support user is not a real user and will never be
+ # that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
- is_support = yield self.store.is_support_user(user_id)
- # There is an edge case where the first user is the support user, then
- # the room is never created, though this seems unlikely and
- # recoverable from given the support user being involved in the first
- # place.
- if self.hs.config.autocreate_auto_join_rooms and not is_support:
- count = yield self.store.count_all_users()
+ is_real_user = yield self.store.is_real_user(user_id)
+ if self.hs.config.autocreate_auto_join_rooms and is_real_user:
+ count = yield self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
@@ -383,10 +377,10 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _generate_user_id(self, reseed=False):
- if reseed or self._next_generated_user_id is None:
+ def _generate_user_id(self):
+ if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
- if reseed or self._next_generated_user_id is None:
+ if self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a509e11d69..2816bd8f87 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -28,6 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -554,7 +555,8 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", [])
for i in invite_list:
try:
- UserID.from_string(i)
+ uid = UserID.from_string(i)
+ parse_and_validate_server_name(uid.domain)
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
@@ -579,8 +581,8 @@ class RoomCreationHandler(BaseHandler):
room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public)
+ directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
- directory_handler = self.hs.get_handlers().directory_handler
yield directory_handler.create_association(
requester=requester,
room_id=room_id,
@@ -665,6 +667,7 @@ class RoomCreationHandler(BaseHandler):
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
+ id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite(
@@ -675,6 +678,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ id_access_token=id_access_token,
)
result = {"room_id": room_id}
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index a7e55f00e5..c615206df1 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -16,8 +16,7 @@
import logging
from collections import namedtuple
-from six import PY3, iteritems
-from six.moves import range
+from six import iteritems
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -27,7 +26,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
-from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
@@ -37,7 +35,6 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
-
# This is used to indicate we should only return rooms published to the main list.
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
@@ -72,6 +69,8 @@ class RoomListHandler(BaseHandler):
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
+ from_federation (bool): true iff the request comes from the federation
+ API
"""
if not self.enable_room_list_search:
return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
@@ -89,16 +88,8 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
- # XXX: Quick hack to stop room directory queries taking too long.
- # Timeout request after 60s. Probably want a more fundamental
- # solution at some point
- timeout = self.clock.time() + 60
return self._get_public_room_list(
- limit,
- since_token,
- search_filter,
- network_tuple=network_tuple,
- timeout=timeout,
+ limit, since_token, search_filter, network_tuple=network_tuple
)
key = (limit, since_token, network_tuple)
@@ -119,7 +110,6 @@ class RoomListHandler(BaseHandler):
search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID,
from_federation=False,
- timeout=None,
):
"""Generate a public room list.
Args:
@@ -132,240 +122,116 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists.
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
- timeout (int|None): Amount of seconds to wait for a response before
- timing out.
"""
- if since_token and since_token != "END":
- since_token = RoomListNextBatch.from_token(since_token)
- else:
- since_token = None
- rooms_to_order_value = {}
- rooms_to_num_joined = {}
+ # Pagination tokens work by storing the joined-member count and room ID of
+ # the last room sent in a batch, plus the direction (forwards or backwards).
+ # Next batch tokens always go forwards, prev batch tokens always go backwards.
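+ #
+ # For example, if the last room in a forwards batch had 42 joined members
+ # and room ID "!abc:example.com" (illustrative values), the next_batch
+ # token encodes (42, "!abc:example.com", forwards).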
- newly_visible = []
- newly_unpublished = []
if since_token:
- stream_token = since_token.stream_ordering
- current_public_id = yield self.store.get_current_public_room_stream_id()
- public_room_stream_id = since_token.public_room_stream_id
- newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
- public_room_stream_id, current_public_id, network_tuple=network_tuple
- )
- else:
- stream_token = yield self.store.get_room_max_stream_ordering()
- public_room_stream_id = yield self.store.get_current_public_room_stream_id()
-
- room_ids = yield self.store.get_public_room_ids_at_stream_id(
- public_room_stream_id, network_tuple=network_tuple
- )
-
- # We want to return rooms in a particular order: the number of joined
- # users. We then arbitrarily use the room_id as a tie breaker.
-
- @defer.inlineCallbacks
- def get_order_for_room(room_id):
- # Most of the rooms won't have changed between the since token and
- # now (especially if the since token is "now"). So, we can ask what
- # the current users are in a room (that will hit a cache) and then
- # check if the room has changed since the since token. (We have to
- # do it in that order to avoid races).
- # If things have changed then fall back to getting the current state
- # at the since token.
- joined_users = yield self.store.get_users_in_room(room_id)
- if self.store.has_room_changed_since(room_id, stream_token):
- latest_event_ids = yield self.store.get_forward_extremeties_for_room(
- room_id, stream_token
- )
-
- if not latest_event_ids:
- return
+ batch_token = RoomListNextBatch.from_token(since_token)
- joined_users = yield self.state_handler.get_current_users_in_room(
- room_id, latest_event_ids
- )
-
- num_joined_users = len(joined_users)
- rooms_to_num_joined[room_id] = num_joined_users
+ bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+ forwards = batch_token.direction_is_forward
+ else:
+ batch_token = None
+ bounds = None
- if num_joined_users == 0:
- return
+ forwards = True
- # We want larger rooms to be first, hence negating num_joined_users
- rooms_to_order_value[room_id] = (-num_joined_users, room_id)
+ # we request one more than wanted to see if there are more pages to come
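+ # (e.g. with limit=10 we ask the store for 11 rooms; getting 11 back means
+ # there is at least one further page)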
+ probing_limit = limit + 1 if limit is not None else None
- logger.info(
- "Getting ordering for %i rooms since %s", len(room_ids), stream_token
+ results = yield self.store.get_largest_public_rooms(
+ network_tuple,
+ search_filter,
+ probing_limit,
+ bounds=bounds,
+ forwards=forwards,
+ ignore_non_federatable=from_federation,
)
- yield concurrently_execute(get_order_for_room, room_ids, 10)
- sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
- sorted_rooms = [room_id for room_id, _ in sorted_entries]
+ def build_room_entry(room):
+ entry = {
+ "room_id": room["room_id"],
+ "name": room["name"],
+ "topic": room["topic"],
+ "canonical_alias": room["canonical_alias"],
+ "num_joined_members": room["joined_members"],
+ "avatar_url": room["avatar"],
+ "world_readable": room["history_visibility"] == "world_readable",
+ "guest_can_join": room["guest_access"] == "can_join",
+ }
- # `sorted_rooms` should now be a list of all public room ids that is
- # stable across pagination. Therefore, we can use indices into this
- # list as our pagination tokens.
+ # Filter out Nones: omit the field altogether rather than sending a null value
+ return {k: v for k, v in entry.items() if v is not None}
- # Filter out rooms that we don't want to return
- rooms_to_scan = [
- r
- for r in sorted_rooms
- if r not in newly_unpublished and rooms_to_num_joined[r] > 0
- ]
+ results = [build_room_entry(r) for r in results]
- total_room_count = len(rooms_to_scan)
+ response = {}
+ num_results = len(results)
+ if limit is not None:
+ more_to_come = num_results == probing_limit
- if since_token:
- # Filter out rooms we've already returned previously
- # `since_token.current_limit` is the index of the last room we
- # sent down, so we exclude it and everything before/after it.
- if since_token.direction_is_forward:
- rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :]
+ # Depending on direction we trim either the front or back.
+ if forwards:
+ results = results[:limit]
else:
- rooms_to_scan = rooms_to_scan[: since_token.current_limit]
- rooms_to_scan.reverse()
-
- logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan))
-
- # _append_room_entry_to_chunk will append to chunk but will stop if
- # len(chunk) > limit
- #
- # Normally we will generate enough results on the first iteration here,
- # but if there is a search filter, _append_room_entry_to_chunk may
- # filter some results out, in which case we loop again.
- #
- # We don't want to scan over the entire range either as that
- # would potentially waste a lot of work.
- #
- # XXX if there is no limit, we may end up DoSing the server with
- # calls to get_current_state_ids for every single room on the
- # server. Surely we should cap this somehow?
- #
- if limit:
- step = limit + 1
+ results = results[-limit:]
else:
- # step cannot be zero
- step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
-
- chunk = []
- for i in range(0, len(rooms_to_scan), step):
- if timeout and self.clock.time() > timeout:
- raise Exception("Timed out searching room directory")
-
- batch = rooms_to_scan[i : i + step]
- logger.info("Processing %i rooms for result", len(batch))
- yield concurrently_execute(
- lambda r: self._append_room_entry_to_chunk(
- r,
- rooms_to_num_joined[r],
- chunk,
- limit,
- search_filter,
- from_federation=from_federation,
- ),
- batch,
- 5,
- )
- logger.info("Now %i rooms in result", len(chunk))
- if len(chunk) >= limit + 1:
- break
-
- chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
-
- # Work out the new limit of the batch for pagination, or None if we
- # know there are no more results that would be returned.
- # i.e., [since_token.current_limit..new_limit] is the batch of rooms
- # we've returned (or the reverse if we paginated backwards)
- # We tried to pull out limit + 1 rooms above, so if we have <= limit
- # then we know there are no more results to return
- new_limit = None
- if chunk and (not limit or len(chunk) > limit):
-
- if not since_token or since_token.direction_is_forward:
- if limit:
- chunk = chunk[:limit]
- last_room_id = chunk[-1]["room_id"]
+ more_to_come = False
+
+ if num_results > 0:
+ final_entry = results[-1]
+ initial_entry = results[0]
+
+ if forwards:
+ if batch_token:
+ # If there was a token given then we assume that there
+ # must be previous results.
+ response["prev_batch"] = RoomListNextBatch(
+ last_joined_members=initial_entry["num_joined_members"],
+ last_room_id=initial_entry["room_id"],
+ direction_is_forward=False,
+ ).to_token()
+
+ if more_to_come:
+ response["next_batch"] = RoomListNextBatch(
+ last_joined_members=final_entry["num_joined_members"],
+ last_room_id=final_entry["room_id"],
+ direction_is_forward=True,
+ ).to_token()
else:
- if limit:
- chunk = chunk[-limit:]
- last_room_id = chunk[0]["room_id"]
-
- new_limit = sorted_rooms.index(last_room_id)
-
- results = {"chunk": chunk, "total_room_count_estimate": total_room_count}
-
- if since_token:
- results["new_rooms"] = bool(newly_visible)
-
- if not since_token or since_token.direction_is_forward:
- if new_limit is not None:
- results["next_batch"] = RoomListNextBatch(
- stream_ordering=stream_token,
- public_room_stream_id=public_room_stream_id,
- current_limit=new_limit,
- direction_is_forward=True,
- ).to_token()
-
- if since_token:
- results["prev_batch"] = since_token.copy_and_replace(
- direction_is_forward=False,
- current_limit=since_token.current_limit + 1,
- ).to_token()
- else:
- if new_limit is not None:
- results["prev_batch"] = RoomListNextBatch(
- stream_ordering=stream_token,
- public_room_stream_id=public_room_stream_id,
- current_limit=new_limit,
- direction_is_forward=False,
- ).to_token()
-
- if since_token:
- results["next_batch"] = since_token.copy_and_replace(
- direction_is_forward=True,
- current_limit=since_token.current_limit - 1,
- ).to_token()
-
- return results
-
- @defer.inlineCallbacks
- def _append_room_entry_to_chunk(
- self,
- room_id,
- num_joined_users,
- chunk,
- limit,
- search_filter,
- from_federation=False,
- ):
- """Generate the entry for a room in the public room list and append it
- to the `chunk` if it matches the search filter
-
- Args:
- room_id (str): The ID of the room.
- num_joined_users (int): The number of joined users in the room.
- chunk (list)
- limit (int|None): Maximum amount of rooms to display. Function will
- return if length of chunk is greater than limit + 1.
- search_filter (dict|None)
- from_federation (bool): Whether this request originated from a
- federating server or a client. Used for room filtering.
- """
- if limit and len(chunk) > limit + 1:
- # We've already got enough, so lets just drop it.
- return
+ if batch_token:
+ response["next_batch"] = RoomListNextBatch(
+ last_joined_members=final_entry["num_joined_members"],
+ last_room_id=final_entry["room_id"],
+ direction_is_forward=True,
+ ).to_token()
+
+ if more_to_come:
+ response["prev_batch"] = RoomListNextBatch(
+ last_joined_members=initial_entry["num_joined_members"],
+ last_room_id=initial_entry["room_id"],
+ direction_is_forward=False,
+ ).to_token()
+
+ for room in results:
+ # populate search result entries with additional fields, namely
+ # 'aliases'
+ room_id = room["room_id"]
+
+ aliases = yield self.store.get_aliases_for_room(room_id)
+ if aliases:
+ room["aliases"] = aliases
- result = yield self.generate_room_entry(room_id, num_joined_users)
- if not result:
- return
+ response["chunk"] = results
- if from_federation and not result.get("m.federate", True):
- # This is a room that other servers cannot join. Do not show them
- # this room.
- return
+ response["total_room_count_estimate"] = yield self.store.count_public_rooms(
+ network_tuple, ignore_non_federatable=from_federation
+ )
- if _matches_room_entry(result, search_filter):
- chunk.append(result)
+ return response
@cachedInlineCallbacks(num_args=1, cache_context=True)
def generate_room_entry(
@@ -580,18 +446,15 @@ class RoomListNextBatch(
namedtuple(
"RoomListNextBatch",
(
- "stream_ordering", # stream_ordering of the first public room list
- "public_room_stream_id", # public room stream id for first public room list
- "current_limit", # The number of previous rooms returned
+ "last_joined_members", # The count to get rooms after/before
+ "last_room_id", # The room_id to get rooms after/before
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
),
)
):
-
KEY_DICT = {
- "stream_ordering": "s",
- "public_room_stream_id": "p",
- "current_limit": "n",
+ "last_joined_members": "m",
+ "last_room_id": "r",
"direction_is_forward": "d",
}
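+
+ # Tokens are the unpadded-base64 encoding of the msgpacked dict, using the
+ # short key names above: e.g. (illustrative)
+ # {"m": 42, "r": "!abc:example.com", "d": True}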
@@ -599,13 +462,7 @@ class RoomListNextBatch(
@classmethod
def from_token(cls, token):
- if PY3:
- # The argument raw=False is only available on new versions of
- # msgpack, and only really needed on Python 3. Gate it behind
- # a PY3 check to avoid causing issues on Debian-packaged versions.
- decoded = msgpack.loads(decode_base64(token), raw=False)
- else:
- decoded = msgpack.loads(decode_base64(token))
+ decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 093f2ea36e..380e2fad5e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -20,15 +20,11 @@ import logging
from six.moves import http_client
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
-from unpaddedbase64 import decode_base64
-
from twisted.internet import defer
from synapse import types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -37,8 +33,6 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of
@@ -59,10 +53,10 @@ class RoomMemberHandler(object):
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
- self.simple_http_client = hs.get_simple_http_client()
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
+ self.identity_handler = hs.get_handlers().identity_handler
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -100,7 +94,7 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- def _remote_reject_invite(self, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
@@ -209,23 +203,11 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield self._user_joined_room(target, room_id)
-
- # Copy over direct message status and room tags if this is a join
- # on an upgraded room
-
- # Check if this is an upgraded room
- predecessor = yield self.store.get_room_predecessor(room_id)
-
- if predecessor:
- # It is an upgraded room. Copy over old tags
- self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], room_id, user_id
- )
- # Move over old push rules
- self.store.move_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], room_id, user_id
+ # Copy over user state if we're joining an upgraded room
+ yield self.copy_user_state_if_room_upgrade(
+ room_id, requester.user.to_string()
)
+ yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@@ -469,10 +451,16 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- ret = yield self._remote_join(
+ remote_join_response = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
- return ret
+
+ # Copy over user state if this is a join on a remote upgraded room
+ yield self.copy_user_state_if_room_upgrade(
+ room_id, requester.user.to_string()
+ )
+
+ return remote_join_response
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@@ -510,9 +498,39 @@ class RoomMemberHandler(object):
return res
@defer.inlineCallbacks
- def send_membership_event(
- self, requester, event, context, remote_room_hosts=None, ratelimit=True
- ):
+ def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
+ """Copy user-specific information when they join a new room if that new room is the
+ result of a room upgrade
+
+ Args:
+ new_room_id (str): The ID of the room the user is joining
+ user_id (str): The ID of the user
+
+ Returns:
+ Deferred
+ """
+ # Check if the new room is an upgraded room
+ predecessor = yield self.store.get_room_predecessor(new_room_id)
+ if not predecessor:
+ return
+
+ logger.debug(
+ "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+ new_room_id,
+ predecessor,
+ )
+
+ # It is an upgraded room. Copy over old tags
+ yield self.copy_room_tags_and_direct_to_room(
+ predecessor["room_id"], new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ predecessor["room_id"], new_room_id, user_id
+ )
+
+ @defer.inlineCallbacks
+ def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@@ -522,16 +540,10 @@ class RoomMemberHandler(object):
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
- is_guest (bool): Whether the sender is a guest.
- room_hosts ([str]): Homeservers which are likely to already be in
- the room, and could be danced with in order to join this
- homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
- remote_room_hosts = remote_room_hosts or []
-
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
@@ -634,7 +646,7 @@ class RoomMemberHandler(object):
servers.remove(room_alias.domain)
servers.insert(0, room_alias.domain)
- return (RoomID.from_string(room_id), servers)
+ return RoomID.from_string(room_id), servers
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
@@ -646,7 +658,15 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def do_3pid_invite(
- self, room_id, inviter, medium, address, id_server, requester, txn_id
+ self,
+ room_id,
+ inviter,
+ medium,
+ address,
+ id_server,
+ requester,
+ txn_id,
+ id_access_token=None,
):
if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
@@ -669,75 +689,42 @@ class RoomMemberHandler(object):
Codes.FORBIDDEN,
)
- invitee = yield self._lookup_3pid(id_server, medium, address)
-
- if invitee:
- yield self.update_membership(
- requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
- )
- else:
- yield self._make_and_store_3pid_invite(
- requester, id_server, medium, address, room_id, inviter, txn_id=txn_id
- )
-
- @defer.inlineCallbacks
- def _lookup_3pid(self, id_server, medium, address):
- """Looks up a 3pid in the passed identity server.
-
- Args:
- id_server (str): The server name (including port, if required)
- of the identity server to use.
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
-
- Returns:
- str: the matrix ID of the 3pid, or None if it is not recognized.
- """
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
)
- try:
- data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
- {"medium": medium, "address": address},
- )
-
- if "mxid" in data:
- if "signatures" not in data:
- raise AuthError(401, "No signatures on 3pid binding")
- yield self._verify_any_signature(data, id_server)
- return data["mxid"]
- except IOError as e:
- logger.warn("Error from identity server lookup: %s" % (e,))
- return None
+ invitee = yield self.identity_handler.lookup_3pid(
+ id_server, medium, address, id_access_token
+ )
- @defer.inlineCallbacks
- def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
- for key_name, signature in data["signatures"][server_hostname].items():
- key_data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
+ if invitee:
+ yield self.update_membership(
+ requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
- if "public_key" not in key_data:
- raise AuthError(
- 401, "No public key named %s from %s" % (key_name, server_hostname)
- )
- verify_signed_json(
- data,
- server_hostname,
- decode_verify_key_bytes(
- key_name, decode_base64(key_data["public_key"])
- ),
+ else:
+ yield self._make_and_store_3pid_invite(
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ inviter,
+ txn_id=txn_id,
+ id_access_token=id_access_token,
)
- return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
- self, requester, id_server, medium, address, room_id, user, txn_id
+ self,
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ user,
+ txn_id,
+ id_access_token=None,
):
room_state = yield self.state_handler.get_current_state(room_id)
@@ -773,7 +760,7 @@ class RoomMemberHandler(object):
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
- yield self._ask_id_server_for_third_party_invite(
+ yield self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@@ -786,6 +773,7 @@ class RoomMemberHandler(object):
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url,
+ id_access_token=id_access_token,
)
)
@@ -809,103 +797,6 @@ class RoomMemberHandler(object):
)
@defer.inlineCallbacks
- def _ask_id_server_for_third_party_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter_user_id,
- room_alias,
- room_avatar_url,
- room_join_rules,
- room_name,
- inviter_display_name,
- inviter_avatar_url,
- ):
- """
- Asks an identity server for a third party invite.
-
- Args:
- requester (Requester)
- id_server (str): hostname + optional port for the identity server.
- medium (str): The literal string "email".
- address (str): The third party address being invited.
- room_id (str): The ID of the room to which the user is invited.
- inviter_user_id (str): The user ID of the inviter.
- room_alias (str): An alias for the room, for cosmetic notifications.
- room_avatar_url (str): The URL of the room's avatar, for cosmetic
- notifications.
- room_join_rules (str): The join rules of the email (e.g. "public").
- room_name (str): The m.room.name of the room.
- inviter_display_name (str): The current display name of the
- inviter.
- inviter_avatar_url (str): The URL of the inviter's avatar.
-
- Returns:
- A deferred tuple containing:
- token (str): The token which must be signed to prove authenticity.
- public_keys ([{"public_key": str, "key_validity_url": str}]):
- public_key is a base64-encoded ed25519 public key.
- fallback_public_key: One element from public_keys.
- display_name (str): A user-friendly name to represent the invited
- user.
- """
-
- is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
- id_server_scheme,
- id_server,
- )
-
- invite_config = {
- "medium": medium,
- "address": address,
- "room_id": room_id,
- "room_alias": room_alias,
- "room_avatar_url": room_avatar_url,
- "room_join_rules": room_join_rules,
- "room_name": room_name,
- "sender": inviter_user_id,
- "sender_display_name": inviter_display_name,
- "sender_avatar_url": inviter_avatar_url,
- }
-
- try:
- data = yield self.simple_http_client.post_json_get_json(
- is_url, invite_config
- )
- except HttpResponseException as e:
- # Some identity servers may only support application/x-www-form-urlencoded
- # types. This is especially true with old instances of Sydent, see
- # https://github.com/matrix-org/sydent/pull/170
- logger.info(
- "Failed to POST %s with JSON, falling back to urlencoded form: %s",
- is_url,
- e,
- )
- data = yield self.simple_http_client.post_urlencoded_get_json(
- is_url, invite_config
- )
-
- # TODO: Check for success
- token = data["token"]
- public_keys = data.get("public_keys", [])
- if "public_key" in data:
- fallback_public_key = {
- "public_key": data["public_key"],
- "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid"
- % (id_server_scheme, id_server),
- }
- else:
- fallback_public_key = public_keys[0]
-
- if not public_keys:
- public_keys.append(fallback_public_key)
- display_name = data["display_name"]
- return token, public_keys, fallback_public_key, display_name
-
- @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
@@ -1057,7 +948,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
+ logger.warning("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(target.to_string(), room_id)
return {}
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a1ce6929cf..cc9e6b9bd0 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
+from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -29,12 +31,26 @@ class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs)
+ self._registration_handler = hs.get_registration_handler()
+
+ self._clock = hs.get_clock()
+ self._datastore = hs.get_datastore()
+ self._hostname = hs.hostname
+ self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
+ self._grandfathered_mxid_source_attribute = (
+ hs.config.saml2_grandfathered_mxid_source_attribute
+ )
+ self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # identifier for the external_ids table
+ self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {}
- self._clock = hs.get_clock()
- self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
@@ -60,7 +76,7 @@ class SamlHandler:
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")
- def handle_saml_response(self, request):
+ async def handle_saml_response(self, request):
"""Handle an incoming request to /_matrix/saml2/authn_response
Args:
@@ -77,6 +93,10 @@ class SamlHandler:
# the dict.
self.expire_sessions()
+ user_id = await self._map_saml_response_to_user(resp_bytes)
+ self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
+
+ async def _map_saml_response_to_user(self, resp_bytes):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -91,18 +111,88 @@ class SamlHandler:
logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")
- if "uid" not in saml2_auth.ava:
+ logger.info("SAML2 response: %s", saml2_auth.origxml)
+ logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+
+ try:
+ remote_user_id = saml2_auth.ava["uid"][0]
+ except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "uid not in SAML2 response")
+ try:
+ mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- username = saml2_auth.ava["uid"][0]
displayName = saml2_auth.ava.get("displayName", [None])[0]
- return self._sso_auth_handler.on_successful_auth(
- username, request, relay_state, user_display_name=displayName
- )
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
+ # first of all, check if we already have a mapping for this user
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ self._auth_provider_id,
+ remote_user_id,
+ )
+ registered_user_id = await self._datastore.get_user_by_external_id(
+ self._auth_provider_id, remote_user_id
+ )
+ if registered_user_id is not None:
+ logger.info("Found existing mapping %s", registered_user_id)
+ return registered_user_id
+
+ # backwards-compatibility hack: see if there is an existing user with a
+ # suitable mapping from the uid
+ if (
+ self._grandfathered_mxid_source_attribute
+ and self._grandfathered_mxid_source_attribute in saml2_auth.ava
+ ):
+ attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
+ user_id = UserID(
+ map_username_to_mxid_localpart(attrval), self._hostname
+ ).to_string()
+ logger.info(
+ "Looking for existing account based on mapped %s %s",
+ self._grandfathered_mxid_source_attribute,
+ user_id,
+ )
+
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
+
+ # figure out a new mxid for this user
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ suffix = 0
+ while True:
+ localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ if not await self._datastore.get_users_by_id_case_insensitive(
+ UserID(localpart, self._hostname).to_string()
+ ):
+ break
+ suffix += 1
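+ # (e.g. if "alice" is taken we try "alice1", then "alice2", and so on)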
+ logger.info("Allocating mxid for new user with localpart %s", localpart)
+
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=displayName
+ )
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 3c265f3718..466daf9202 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -84,17 +84,26 @@ class StatsHandler(StateDeltasHandler):
# Loop round handling deltas until we're up to date
while True:
- deltas = yield self.store.get_current_state_deltas(self.pos)
+ # Be sure to read the max stream_ordering *before* checking if there are any outstanding
+ # deltas, since there is otherwise a chance that we could miss updates which arrive
+ # after we check the deltas.
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos == room_max_stream_ordering:
+ break
+
+ logger.debug(
+ "Processing room stats %s->%s", self.pos, room_max_stream_ordering
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self.pos, room_max_stream_ordering
+ )
if deltas:
logger.debug("Handling %d state deltas", len(deltas))
room_deltas, user_deltas = yield self._handle_deltas(deltas)
-
- max_pos = deltas[-1]["stream_id"]
else:
room_deltas = {}
user_deltas = {}
- max_pos = yield self.store.get_room_max_stream_ordering()
# Then count deltas for total_events and total_event_bytes.
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
@@ -117,10 +126,9 @@ class StatsHandler(StateDeltasHandler):
stream_id=max_pos,
)
- event_processing_positions.labels("stats").set(max_pos)
+ logger.debug("Handled room stats to %s -> %s", self.pos, max_pos)
- if self.pos == max_pos:
- break
+ event_processing_positions.labels("stats").set(max_pos)
self.pos = max_pos
@@ -287,6 +295,7 @@ class StatsHandler(StateDeltasHandler):
room_state["guest_access"] = event_content.get("guest_access")
for room_id, state in room_to_state_updates.items():
+ logger.info("Updating room_stats_state for %s: %s", room_id, state)
yield self.store.update_room_state(room_id, state)
return room_to_stats_deltas, user_to_stats_deltas
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
new file mode 100644
index 0000000000..824f37f8f8
--- /dev/null
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module implements user-interactive auth verification.
+
+TODO: move more stuff out of AuthHandler in here.
+
+"""
+
+from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
new file mode 100644
index 0000000000..29aa1e5aaf
--- /dev/null
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+from twisted.web.client import PartialDownloadError
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.config.emailconfig import ThreepidBehaviour
+
+logger = logging.getLogger(__name__)
+
+
+class UserInteractiveAuthChecker:
+ """Abstract base class for an interactive auth checker"""
+
+ def __init__(self, hs):
+ pass
+
+ def is_enabled(self):
+ """Check if the configuration of the homeserver allows this checker to work
+
+ Returns:
+ bool: True if this login type is enabled.
+ """
+ raise NotImplementedError()
+
+ def check_auth(self, authdict, clientip):
+ """Given the authentication dict from the client, attempt to check this step
+
+ Args:
+ authdict (dict): authentication dictionary from the client
+ clientip (str): The IP address of the client.
+
+ Raises:
+ SynapseError if authentication failed
+
+ Returns:
+            Deferred: the result of authentication (passed back to the client)
+ """
+ raise NotImplementedError()
+
+
+class DummyAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.DUMMY
+
+ def is_enabled(self):
+ return True
+
+ def check_auth(self, authdict, clientip):
+ return defer.succeed(True)
+
+
+class TermsAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.TERMS
+
+ def is_enabled(self):
+ return True
+
+ def check_auth(self, authdict, clientip):
+ return defer.succeed(True)
+
+
+class RecaptchaAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.RECAPTCHA
+
+ def __init__(self, hs):
+ super().__init__(hs)
+ self._enabled = bool(hs.config.recaptcha_private_key)
+ self._http_client = hs.get_simple_http_client()
+ self._url = hs.config.recaptcha_siteverify_api
+ self._secret = hs.config.recaptcha_private_key
+
+ def is_enabled(self):
+ return self._enabled
+
+ @defer.inlineCallbacks
+ def check_auth(self, authdict, clientip):
+ try:
+ user_response = authdict["response"]
+ except KeyError:
+ # Client tried to provide captcha but didn't give the parameter:
+ # bad request.
+ raise LoginError(
+ 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
+ )
+
+ logger.info(
+ "Submitting recaptcha response %s with remoteip %s", user_response, clientip
+ )
+
+ try:
+ resp_body = yield self._http_client.post_urlencoded_get_json(
+ self._url,
+ args={
+ "secret": self._secret,
+ "response": user_response,
+ "remoteip": clientip,
+ },
+ )
+ except PartialDownloadError as pde:
+            # Twisted raises this when the body it received doesn't match the
+            # declared Content-Length; the bytes received so far are attached
+            # to the error and are assumed here to be the full JSON response.
+ data = pde.response
+ resp_body = json.loads(data)
+
+ if "success" in resp_body:
+ # Note that we do NOT check the hostname here: we explicitly
+ # intend the CAPTCHA to be presented by whatever client the
+ # user is using, we just care that they have completed a CAPTCHA.
+ logger.info(
+ "%s reCAPTCHA from hostname %s",
+ "Successful" if resp_body["success"] else "Failed",
+ resp_body.get("hostname"),
+ )
+ if resp_body["success"]:
+ return True
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+
+class _BaseThreepidAuthChecker:
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def _check_threepid(self, medium, authdict):
+ if "threepid_creds" not in authdict:
+ raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
+
+ threepid_creds = authdict["threepid_creds"]
+
+ identity_handler = self.hs.get_handlers().identity_handler
+
+ logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
+
+ # msisdns are currently always ThreepidBehaviour.REMOTE
+ if medium == "msisdn":
+ if not self.hs.config.account_threepid_delegate_msisdn:
+ raise SynapseError(
+ 400, "Phone number verification is not enabled on this homeserver"
+ )
+ threepid = yield identity_handler.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ )
+ elif medium == "email":
+ if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
+ threepid = yield identity_handler.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_email, threepid_creds
+ )
+ elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ threepid = None
+ row = yield self.store.get_threepid_validation_session(
+ medium,
+ threepid_creds["client_secret"],
+ sid=threepid_creds["sid"],
+ validated=True,
+ )
+
+ if row:
+ threepid = {
+ "medium": row["medium"],
+ "address": row["address"],
+ "validated_at": row["validated_at"],
+ }
+
+ # Valid threepid returned, delete from the db
+ yield self.store.delete_threepid_session(threepid_creds["sid"])
+ else:
+ raise SynapseError(
+ 400, "Email address verification is not enabled on this homeserver"
+ )
+ else:
+ # this can't happen!
+ raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
+
+ if not threepid:
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+ if threepid["medium"] != medium:
+ raise LoginError(
+ 401,
+ "Expecting threepid of type '%s', got '%s'"
+ % (medium, threepid["medium"]),
+ errcode=Codes.UNAUTHORIZED,
+ )
+
+ threepid["threepid_creds"] = authdict["threepid_creds"]
+
+ return threepid
+
+
+class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
+ AUTH_TYPE = LoginType.EMAIL_IDENTITY
+
+ def __init__(self, hs):
+ UserInteractiveAuthChecker.__init__(self, hs)
+ _BaseThreepidAuthChecker.__init__(self, hs)
+
+ def is_enabled(self):
+ return self.hs.config.threepid_behaviour_email in (
+ ThreepidBehaviour.REMOTE,
+ ThreepidBehaviour.LOCAL,
+ )
+
+ def check_auth(self, authdict, clientip):
+ return self._check_threepid("email", authdict)
+
+
+class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
+ AUTH_TYPE = LoginType.MSISDN
+
+ def __init__(self, hs):
+ UserInteractiveAuthChecker.__init__(self, hs)
+ _BaseThreepidAuthChecker.__init__(self, hs)
+
+ def is_enabled(self):
+ return bool(self.hs.config.account_threepid_delegate_msisdn)
+
+ def check_auth(self, authdict, clientip):
+ return self._check_threepid("msisdn", authdict)
+
+
+INTERACTIVE_AUTH_CHECKERS = [
+ DummyAuthChecker,
+ TermsAuthChecker,
+ RecaptchaAuthChecker,
+ EmailIdentityAuthChecker,
+ MsisdnAuthChecker,
+]
+"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index e53669e40d..624f05ab5b 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -138,21 +138,28 @@ class UserDirectoryHandler(StateDeltasHandler):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
- deltas = yield self.store.get_current_state_deltas(self.pos)
- if not deltas:
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos == room_max_stream_ordering:
return
+ logger.debug(
+ "Processing user stats %s->%s", self.pos, room_max_stream_ordering
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self.pos, room_max_stream_ordering
+ )
+
logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
- self.pos = deltas[-1]["stream_id"]
+ self.pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set(
- self.pos
+ max_pos
)
- yield self.store.update_user_directory_stream_pos(self.pos)
+ yield self.store.update_user_directory_stream_pos(max_pos)
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
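[Editor's note: condensed, the catch-up pattern adopted here (and in the stats handler change above) looks like the following. This is a paraphrase of the code in this diff, not new behaviour.]

```python
while True:
    # get_room_max_stream_ordering() is a plain synchronous getter
    max_stream = self.store.get_room_max_stream_ordering()
    if self.pos == max_stream:
        return  # fully caught up
    # fetch state deltas in (self.pos, max_stream], bounded by the snapshot
    max_pos, deltas = yield self.store.get_current_state_deltas(
        self.pos, max_stream
    )
    yield self._handle_deltas(deltas)
    self.pos = max_pos  # advance only after the deltas are fully handled
```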
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 3acf772cd1..3880ce0d94 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -42,11 +42,13 @@ def cancelled_to_request_timed_out_error(value, timeout):
ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
+CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
def redact_uri(uri):
- """Strips access tokens from the uri replaces with <redacted>"""
- return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
+ """Strips sensitive information from the uri replaces with <redacted>"""
+ uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
+ return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
class QuieterFileBodyProducer(FileBodyProducer):
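[Editor's note: the intended behaviour of the extended redaction, with illustrative values.]

```python
redact_uri("/_matrix/client/r0/sync?access_token=seekrit")
# -> '/_matrix/client/r0/sync?access_token=<redacted>'

redact_uri("/path/submit_token?client_secret=abc123&sid=42")
# -> '/path/submit_token?client_secret=<redacted>&sid=42'
```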
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 51765ae3c0..cdf828a4ff 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -327,7 +327,7 @@ class SimpleHttpClient(object):
Args:
uri (str):
args (dict[str, str|List[str]]): query params
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
@@ -371,7 +371,7 @@ class SimpleHttpClient(object):
Args:
uri (str):
post_json (object):
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
@@ -414,7 +414,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@@ -438,7 +438,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@@ -482,7 +482,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@@ -516,7 +516,7 @@ class SimpleHttpClient(object):
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index feae7de5be..647d26dc56 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -217,7 +217,7 @@ class MatrixHostnameEndpoint(object):
self._tls_options = None
else:
self._tls_options = tls_client_options_factory.get_options(
- self._parsed_uri.host.decode("ascii")
+ self._parsed_uri.host
)
self._srv_resolver = srv_resolver
diff --git a/synapse/http/server.py b/synapse/http/server.py
index cb9158fe1b..2ccb210fd6 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -388,7 +388,7 @@ class DirectServeResource(resource.Resource):
if not callback:
return super().render(request)
- resp = callback(request)
+ resp = trace_servlet(self.__class__.__name__)(callback)(request)
# If it's a coroutine, turn it into a Deferred
if isinstance(resp, types.CoroutineType):
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 0367d6dfc4..3220e985a9 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -18,6 +18,7 @@ import os.path
import sys
import typing
import warnings
+from typing import List
import attr
from constantly import NamedConstant, Names, ValueConstant, Values
@@ -33,7 +34,6 @@ from twisted.logger import (
LogLevelFilterPredicate,
LogPublisher,
eventAsText,
- globalLogBeginner,
jsonFileLogObserver,
)
@@ -134,7 +134,7 @@ class PythonStdlibToTwistedLogger(logging.Handler):
)
-def SynapseFileLogObserver(outFile: typing.io.TextIO) -> FileLogObserver:
+def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
"""
A log observer that formats events like the traditional log formatter and
sends them to `outFile`.
@@ -265,7 +265,7 @@ def setup_structured_logging(
hs,
config,
log_config: dict,
- logBeginner: LogBeginner = globalLogBeginner,
+ logBeginner: LogBeginner,
redirect_stdlib_logging: bool = True,
) -> LogPublisher:
"""
@@ -286,7 +286,7 @@ def setup_structured_logging(
if "drains" not in log_config:
raise ConfigError("The logging configuration requires a list of drains.")
- observers = []
+ observers = [] # type: List[ILogObserver]
for observer in parse_drain_configs(log_config["drains"]):
# Pipe drains
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 7f1e8f23fe..0ebbde06f2 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -21,10 +21,11 @@ import sys
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
-from typing.io import TextIO
+from typing import IO
import attr
from simplejson import dumps
+from zope.interface import implementer
from twisted.application.internet import ClientService
from twisted.internet.endpoints import (
@@ -33,7 +34,7 @@ from twisted.internet.endpoints import (
TCP6ClientEndpoint,
)
from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import FileLogObserver, Logger
+from twisted.logger import FileLogObserver, ILogObserver, Logger
from twisted.python.failure import Failure
@@ -129,7 +130,7 @@ def flatten_event(event: dict, metadata: dict, include_time: bool = False):
return new_event
-def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObserver:
+def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
"""
A log observer that formats events to a flattened JSON representation.
@@ -146,6 +147,7 @@ def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObs
@attr.s
+@implementer(ILogObserver)
class TerseJSONToTCPLogObserver(object):
"""
An ILogObserver that writes JSON logs to a TCP target.
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 63379bfb93..370000e377 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,13 +43,17 @@ try:
# exception.
resource.getrusage(RUSAGE_THREAD)
+ is_thread_resource_usage_supported = True
+
def get_thread_resource_usage():
return resource.getrusage(RUSAGE_THREAD)
except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
- # won't track resource usage by returning None.
+ # won't track resource usage.
+ is_thread_resource_usage_supported = False
+
def get_thread_resource_usage():
return None
@@ -359,7 +364,11 @@ class LoggingContext(object):
# When we stop, let's record the cpu used since we started
if not self.usage_start:
- logger.warning("Called stop on logcontext %s without calling start", self)
+ # Log a warning on platforms that support thread usage tracking
+ if is_thread_resource_usage_supported:
+ logger.warning(
+ "Called stop on logcontext %s without calling start", self
+ )
return
utime_delta, stime_delta = self._get_cputime()
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 7246253018..0638cec429 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,9 @@ import contextlib
import inspect
import logging
import re
+import types
from functools import wraps
+from typing import Dict
from canonicaljson import json
@@ -223,8 +225,8 @@ try:
from jaeger_client import Config as JaegerConfig
from synapse.logging.scopecontextmanager import LogContextScopeManager
except ImportError:
- JaegerConfig = None
- LogContextScopeManager = None
+ JaegerConfig = None # type: ignore
+ LogContextScopeManager = None # type: ignore
logger = logging.getLogger(__name__)
@@ -547,7 +549,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -584,7 +586,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -639,7 +641,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@@ -653,7 +655,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
- carrier = {}
+ carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
@@ -777,8 +779,7 @@ def trace_servlet(servlet_name, extract_context=False):
return func
@wraps(func)
- @defer.inlineCallbacks
- def _trace_servlet_inner(request, *args, **kwargs):
+ async def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@@ -795,8 +796,14 @@ def trace_servlet(servlet_name, extract_context=False):
scope = start_active_span(servlet_name, tags=request_tags)
with scope:
- result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- return result
+ result = func(request, *args, **kwargs)
+
+ if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+ # Some servlets aren't async and just return results
+ # directly, so we handle that here.
+ return result
+
+ return await result
return _trace_servlet_inner
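[Editor's note: both servlet styles are now accepted by the wrapper; an illustration with invented handler names.]

```python
@trace_servlet("plain_servlet")
def plain_servlet(request):
    # plain return value: handed back directly, nothing to await
    return 200, {}

@trace_servlet("async_servlet")
async def async_servlet(request):
    # coroutine: awaited by the wrapper inside the tracing span
    return 200, {}
```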
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 7df0fa6087..6073fc2725 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG
- s = inspect.currentframe().f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
@@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
- args=None,
+ args=tuple(),
exc_info=None,
)
@@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames():
- s = inspect.currentframe().f_back.f_back
+
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
@@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]):
- s = inspect.currentframe().f_back.f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+ s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 488280b4a6..0b45e1f52a 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,6 +20,7 @@ import os
import platform
import threading
import time
+from typing import Dict, Union
import six
@@ -29,20 +30,20 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricF
from twisted.internet import reactor
+import synapse
from synapse.metrics._exposition import (
MetricsResource,
generate_latest,
start_http_server,
)
+from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
running_on_pypy = platform.python_implementation() == "PyPy"
-all_metrics = []
-all_collectors = []
-all_gauges = {}
+all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]]
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@@ -124,7 +125,7 @@ class InFlightGauge(object):
)
# Counts number of in flight blocks for a given set of label values
- self._registrations = {}
+ self._registrations = {} # type: Dict
# Protects access to _registrations
self._lock = threading.Lock()
@@ -225,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous!
data = self.data_collector()
- buckets = {}
+ buckets = {} # type: Dict[float, int]
res = []
for x in data.keys():
@@ -385,6 +386,16 @@ event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"
# finished being processed.
event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
+# Build info of the running server.
+build_info = Gauge(
+ "synapse_build_info", "Build information", ["pythonversion", "version", "osversion"]
+)
+build_info.labels(
+ " ".join([platform.python_implementation(), platform.python_version()]),
+ get_version_string(synapse),
+ " ".join([platform.system(), platform.release()]),
+).set(1)
+
last_ticked = time.time()
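[Editor's note: on the metrics endpoint this surfaces as a single constant-value sample; a sketch of what a scrape would show, with illustrative label values.]

```python
from prometheus_client import REGISTRY, generate_latest

# After synapse.metrics has been imported, the scrape output contains a line like:
#   synapse_build_info{osversion="Linux 4.19.0",pythonversion="CPython 3.7.4",version="1.4.1"} 1.0
print(generate_latest(REGISTRY).decode("utf-8"))
```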
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 1933ecd3e3..a248103191 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -36,7 +36,9 @@ from twisted.web.resource import Resource
try:
from prometheus_client.samples import Sample
except ImportError:
- Sample = namedtuple("Sample", ["name", "labels", "value", "timestamp", "exemplar"])
+ Sample = namedtuple( # type: ignore[no-redef] # noqa
+ "Sample", ["name", "labels", "value", "timestamp", "exemplar"]
+ )
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index edd6b42db3..c53d2a0d40 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,6 +15,8 @@
import logging
import threading
+from asyncio import iscoroutine
+from functools import wraps
import six
@@ -173,7 +175,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
Args:
desc (str): a description for this background process type
- func: a function, which may return a Deferred
+ func: a function, which may return a Deferred or a coroutine
args: positional args for func
kwargs: keyword args for func
@@ -197,7 +199,17 @@ def run_as_background_process(desc, func, *args, **kwargs):
_background_processes.setdefault(desc, set()).add(proc)
try:
- yield func(*args, **kwargs)
+ result = func(*args, **kwargs)
+
+ # We probably don't have an ensureDeferred in our call stack to handle
+ # coroutine results, so we need to ensureDeferred here.
+ #
+ # But we need this check because ensureDeferred doesn't like being
+ # called on immediate values (as opposed to Deferreds or coroutines).
+ if iscoroutine(result):
+ result = defer.ensureDeferred(result)
+
+ return (yield result)
except Exception:
logger.exception("Background process '%s' threw an exception", desc)
finally:
@@ -208,3 +220,20 @@ def run_as_background_process(desc, func, *args, **kwargs):
with PreserveLoggingContext():
return run()
+
+
+def wrap_as_background_process(desc):
+ """Decorator that wraps a function that gets called as a background
+ process.
+
+    Equivalent to calling the function with `run_as_background_process`.
+ """
+
+ def wrap_as_background_process_inner(func):
+ @wraps(func)
+ def wrap_as_background_process_inner_2(*args, **kwargs):
+ return run_as_background_process(desc, func, *args, **kwargs)
+
+ return wrap_as_background_process_inner_2
+
+ return wrap_as_background_process_inner
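[Editor's note: a usage sketch for the new decorator; the class, method and description names are invented for illustration.]

```python
from twisted.internet import defer

from synapse.metrics.background_process_metrics import wrap_as_background_process

class ExampleCleaner:
    @wrap_as_background_process("example_cleanup")
    def clean(self, batch_size):
        # Invoked as run_as_background_process("example_cleanup", clean, ...),
        # so it runs in its own logcontext and is tracked in the metrics.
        return defer.succeed(batch_size)

# Fire-and-forget from synchronous code; the returned Deferred may be ignored.
ExampleCleaner().clean(batch_size=50)
```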
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index bd5d53af91..6299587808 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -22,6 +22,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
@@ -194,7 +195,17 @@ class HttpPusher(object):
)
for push_action in unprocessed:
- processed = yield self._process_one(push_action)
+ with opentracing.start_active_span(
+ "http-push",
+ tags={
+ "authenticated_entity": self.user_id,
+ "event_id": push_action["event_id"],
+ "app_id": self.app_id,
+ "app_display_name": self.app_display_name,
+ },
+ ):
+ processed = yield self._process_one(push_action)
+
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 3dfd527849..5b16ab4ae8 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -136,10 +136,11 @@ class Mailer(object):
group together multiple email sending attempts
sid (str): The generated session ID
"""
+ params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
self.hs.config.public_baseurl
- + "_matrix/client/unstable/password_reset/email/submit_token"
- "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid)
+ + "_matrix/client/unstable/password_reset/email/submit_token?%s"
+ % urllib.parse.urlencode(params)
)
template_vars = {"link": link}
@@ -163,10 +164,11 @@ class Mailer(object):
group together multiple email sending attempts
sid (str): The generated session ID
"""
+ params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
self.hs.config.public_baseurl
- + "_matrix/client/unstable/registration/email/submit_token"
- "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid)
+ + "_matrix/client/unstable/registration/email/submit_token?%s"
+ % urllib.parse.urlencode(params)
)
template_vars = {"link": link}
@@ -178,6 +180,35 @@ class Mailer(object):
)
@defer.inlineCallbacks
+ def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+ """Send an email with a validation link to a user for adding a 3pid to their account
+
+ Args:
+ email_address (str): Email address we're sending the validation link to
+
+ token (str): Unique token generated by the server to verify the email was received
+
+ client_secret (str): Unique token generated by the client to group together
+ multiple email sending attempts
+
+ sid (str): The generated session ID
+ """
+ params = {"token": token, "client_secret": client_secret, "sid": sid}
+ link = (
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/add_threepid/email/submit_token?%s"
+ % urllib.parse.urlencode(params)
+ )
+
+ template_vars = {"link": link}
+
+ yield self.send_email(
+ email_address,
+ "[%s] Validate Your Email" % self.hs.config.server_name,
+ template_vars,
+ )
+
+ @defer.inlineCallbacks
def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason
):
@@ -280,7 +311,7 @@ class Mailer(object):
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
- logger.info("Sending email notification to %s" % email_address)
+ logger.info("Sending email to %s" % email_address)
yield make_deferred_yieldable(
self.sendmail(
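[Editor's note: the switch to urlencode matters because tokens and client secrets are not guaranteed to be URL-safe; a quick illustration.]

```python
from urllib.parse import urlencode

params = {"token": "a+b/c==", "client_secret": "s e c", "sid": "7"}
print(urlencode(params))
# token=a%2Bb%2Fc%3D%3D&client_secret=s+e+c&sid=7
```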
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index ec0ac547c1..aa7da1c543 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import List, Set
from pkg_resources import (
DistributionNotFound,
@@ -72,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
+ "typing-extensions>=3.7.4",
]
CONDITIONAL_REQUIREMENTS = {
@@ -97,7 +99,7 @@ CONDITIONAL_REQUIREMENTS = {
"jwt": ["pyjwt>=1.6.4"],
}
-ALL_OPTIONAL_REQUIREMENTS = set()
+ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement.
@@ -143,16 +145,26 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed %s, got %s==%s"
- % (dependency, e.dist.project_name, e.dist.version)
+ % (
+ dependency,
+ e.dist.project_name, # type: ignore[attr-defined] # noqa
+ e.dist.version, # type: ignore[attr-defined] # noqa
+ )
)
except DistributionNotFound:
deps_needed.append(dependency)
- errors.append("Needed %s but it was not installed" % (dependency,))
+ if for_feature:
+ errors.append(
+ "Needed %s for the '%s' feature but it was not installed"
+ % (dependency, for_feature)
+ )
+ else:
+ errors.append("Needed %s but it was not installed" % (dependency,))
if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be
# installed.
- OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
+ OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS:
try:
@@ -161,15 +173,19 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed optional %s, got %s==%s"
- % (dependency, e.dist.project_name, e.dist.version)
+ % (
+ dependency,
+ e.dist.project_name, # type: ignore[attr-defined] # noqa
+ e.dist.version, # type: ignore[attr-defined] # noqa
+ )
)
except DistributionNotFound:
# If it's not found, we don't care
pass
if deps_needed:
- for e in errors:
- logging.error(e)
+ for err in errors:
+ logging.error(err)
raise DependencyException(deps_needed)
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
new file mode 100644
index 0000000000..cc4ab07e09
--- /dev/null
+++ b/synapse/res/templates/add_threepid.html
@@ -0,0 +1,9 @@
+<html>
+<body>
+ <p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p>
+
+ <a href="{{ link }}">{{ link }}</a>
+
+ <p>If this was not you, you can safely ignore this email. Thank you.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/add_threepid.txt b/synapse/res/templates/add_threepid.txt
new file mode 100644
index 0000000000..a60c1ff659
--- /dev/null
+++ b/synapse/res/templates/add_threepid.txt
@@ -0,0 +1,6 @@
+A request to add an email address to your Matrix account has been received. If this was you,
+please click the link below to confirm adding this email:
+
+{{ link }}
+
+If this was not you, you can safely ignore this email. Thank you.
diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
new file mode 100644
index 0000000000..441d11c846
--- /dev/null
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -0,0 +1,8 @@
+<html>
+<head></head>
+<body>
+<p>The request failed for the following reason: {{ failure_reason }}.</p>
+
+<p>No changes have been made to your account.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
new file mode 100644
index 0000000000..fbd6e4018f
--- /dev/null
+++ b/synapse/res/templates/add_threepid_success.html
@@ -0,0 +1,6 @@
+<html>
+<head></head>
+<body>
+<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+</body>
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 81b6bd8816..939418ee2b 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -23,8 +23,6 @@ import re
from six import text_type
from six.moves import http_client
-from twisted.internet import defer
-
import synapse
from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -46,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import UserAdminServlet
from synapse.types import UserID, create_requester
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -59,15 +58,14 @@ class UsersRestServlet(RestServlet):
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
- yield assert_requester_is_admin(self.auth, request)
+ await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
- ret = yield self.handlers.admin_handler.get_users()
+ ret = await self.handlers.admin_handler.get_users()
return 200, ret
@@ -122,8 +120,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce}
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
@@ -204,14 +201,14 @@ class UserRegisterServlet(RestServlet):
register = RegisterRestServlet(self.hs)
- user_id = yield register.registration_handler.register_user(
+ user_id = await register.registration_handler.register_user(
localpart=body["username"].lower(),
password=body["password"],
admin=bool(admin),
user_type=user_type,
)
- result = yield register._create_registration_details(user_id, body)
+ result = await register._create_registration_details(user_id, body)
return 200, result
@@ -223,19 +220,18 @@ class WhoisRestServlet(RestServlet):
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
if target_user != auth_user:
- yield assert_user_is_admin(self.auth, auth_user)
+ await assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
- ret = yield self.handlers.admin_handler.get_whois(target_user)
+ ret = await self.handlers.admin_handler.get_whois(target_user)
return 200, ret
@@ -255,9 +251,8 @@ class PurgeHistoryRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, room_id, event_id):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
@@ -270,12 +265,12 @@ class PurgeHistoryRestServlet(RestServlet):
event_id = body.get("purge_up_to_event_id")
if event_id is not None:
- event = yield self.store.get_event(event_id)
+ event = await self.store.get_event(event_id)
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- token = yield self.store.get_topological_token_for_event(event_id)
+ token = await self.store.get_topological_token_for_event(event_id)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:
@@ -285,12 +280,10 @@ class PurgeHistoryRestServlet(RestServlet):
400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
)
- stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
+ stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
- r = (
- yield self.store.get_room_event_after_stream_ordering(
- room_id, stream_ordering
- )
+ r = await self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering
)
if not r:
logger.warn(
@@ -318,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON,
)
- purge_id = yield self.pagination_handler.start_purge_history(
+ purge_id = self.pagination_handler.start_purge_history(
room_id, token, delete_local_events=delete_local_events
)
@@ -339,9 +332,8 @@ class PurgeHistoryStatusRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, purge_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_GET(self, request, purge_id):
+ await assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id)
if purge_status is None:
@@ -357,9 +349,8 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, target_user_id):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -371,7 +362,7 @@ class DeactivateAccountRestServlet(RestServlet):
UserID.from_string(target_user_id)
- result = yield self._deactivate_account_handler.deactivate_account(
+ result = await self._deactivate_account_handler.deactivate_account(
target_user_id, erase
)
if result:
@@ -405,10 +396,9 @@ class ShutdownRoomRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
@@ -419,7 +409,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = yield self._room_creation_handler.create_room(
+ info = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@@ -438,9 +428,9 @@ class ShutdownRoomRestServlet(RestServlet):
# This will work even if the room is already blocked, but that is
# desirable in case the first attempt at blocking the room failed below.
- yield self.store.block_room(room_id, requester_user_id)
+ await self.store.block_room(room_id, requester_user_id)
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
for user_id in users:
@@ -451,7 +441,7 @@ class ShutdownRoomRestServlet(RestServlet):
try:
target_requester = create_requester(user_id)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@@ -461,9 +451,9 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False,
)
- yield self.room_member_handler.forget(target_requester.user, room_id)
+ await self.room_member_handler.forget(target_requester.user, room_id)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=new_room_id,
@@ -480,7 +470,7 @@ class ShutdownRoomRestServlet(RestServlet):
)
failed_to_kick_users.append(user_id)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
@@ -491,9 +481,11 @@ class ShutdownRoomRestServlet(RestServlet):
ratelimit=False,
)
- aliases_for_room = yield self.store.get_aliases_for_room(room_id)
+ aliases_for_room = await maybe_awaitable(
+ self.store.get_aliases_for_room(room_id)
+ )
- yield self.store.update_aliases_for_room(
+ await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
)
@@ -532,13 +524,12 @@ class ResetPasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self._set_password_handler = hs.get_set_password_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
+ async def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
UserID.from_string(target_user_id)
@@ -546,7 +537,7 @@ class ResetPasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"]
- yield self._set_password_handler.set_password(
+ await self._set_password_handler.set_password(
target_user_id, new_password, requester
)
return 200, {}
@@ -572,12 +563,11 @@ class GetUsersPaginatedRestServlet(RestServlet):
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
- @defer.inlineCallbacks
- def on_GET(self, request, target_user_id):
+ async def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse.
This needs user to have administrator access in Synapse.
"""
- yield assert_requester_is_admin(self.auth, request)
+ await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id)
@@ -590,11 +580,10 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start)
- ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
+ ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
+ async def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse..
This needs user to have administrator access in Synapse.
Example:
@@ -608,7 +597,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- yield assert_requester_is_admin(self.auth, request)
+ await assert_requester_is_admin(self.auth, request)
UserID.from_string(target_user_id)
order = "name" # order by name in user table
@@ -618,7 +607,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
start = params["start"]
logger.info("limit: %s, start: %s", limit, start)
- ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
+ ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret
@@ -641,13 +630,12 @@ class SearchUsersRestServlet(RestServlet):
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
- @defer.inlineCallbacks
- def on_GET(self, request, target_user_id):
+ async def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
search term.
This needs user to have a administrator access in Synapse.
"""
- yield assert_requester_is_admin(self.auth, request)
+ await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id)
@@ -661,7 +649,7 @@ class SearchUsersRestServlet(RestServlet):
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
- ret = yield self.handlers.admin_handler.search_users(term)
+ ret = await self.handlers.admin_handler.search_users(term)
return 200, ret
@@ -676,15 +664,14 @@ class DeleteGroupAdminRestServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only delete local groups")
- yield self.group_server.delete_group(group_id, requester.user.to_string())
+ await self.group_server.delete_group(group_id, requester.user.to_string())
return 200, {}
@@ -700,16 +687,15 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
if "user_id" not in body:
raise SynapseError(400, "Missing property 'user_id' in the request body")
- expiration_ts = yield self.account_activity_handler.renew_account_for_user(
+ expiration_ts = await self.account_activity_handler.renew_account_for_user(
body["user_id"],
body.get("expiration_ts"),
not body.get("enable_renewal_emails", True),
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 5a9b08d3ef..afd0647205 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,8 +15,6 @@
import re
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
@@ -42,8 +40,7 @@ def historical_admin_path_patterns(path_regex):
)
-@defer.inlineCallbacks
-def assert_requester_is_admin(auth, request):
+async def assert_requester_is_admin(auth, request):
"""Verify that the requester is an admin user
WARNING: MAKE SURE YOU AWAIT THE RESULT!
@@ -58,12 +55,11 @@ def assert_requester_is_admin(auth, request):
Raises:
AuthError if the requester is not an admin
"""
- requester = yield auth.get_user_by_req(request)
- yield assert_user_is_admin(auth, requester.user)
+ requester = await auth.get_user_by_req(request)
+ await assert_user_is_admin(auth, requester.user)
-@defer.inlineCallbacks
-def assert_user_is_admin(auth, user_id):
+async def assert_user_is_admin(auth, user_id):
"""Verify that the given user is an admin user
WARNING: MAKE SURE YOU AWAIT THE RESULT!
@@ -79,6 +75,6 @@ def assert_user_is_admin(auth, user_id):
AuthError if the user is not an admin
"""
- is_admin = yield auth.is_server_admin(user_id)
+ is_admin = await auth.is_server_admin(user_id)
if not is_admin:
raise AuthError(403, "You are not a server admin")
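[Editor's note: after the async port, callers await these helpers directly; a minimal sketch, with the servlet name invented.]

```python
from synapse.http.servlet import RestServlet
from synapse.rest.admin._base import assert_requester_is_admin

class ExampleAdminServlet(RestServlet):
    def __init__(self, hs):
        super().__init__()
        self.auth = hs.get_auth()

    async def on_GET(self, request):
        # raises AuthError(403) unless the requester is a server admin
        await assert_requester_is_admin(self.auth, request)
        return 200, {}
```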
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ed7086d09c..fa833e54cf 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_integer
from synapse.rest.admin._base import (
@@ -40,12 +38,11 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
- num_quarantined = yield self.store.quarantine_media_ids_in_room(
+ num_quarantined = await self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string()
)
@@ -62,14 +59,13 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
- local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+ local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
return 200, {"local": local_mxcs, "remote": remote_mxcs}
@@ -81,14 +77,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
logger.info("before_ts: %r", before_ts)
- ret = yield self.media_repository.delete_old_remote_media(before_ts)
+ ret = await self.media_repository.delete_old_remote_media(before_ts)
return 200, ret
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index ae2cbe2e0a..6e9a874121 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -14,8 +14,6 @@
# limitations under the License.
import re
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -69,9 +67,8 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- @defer.inlineCallbacks
- def on_POST(self, request, txn_id=None):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, txn_id=None):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message)
@@ -85,7 +82,7 @@ class SendServerNoticeServlet(RestServlet):
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = yield self.snm.send_notice(
+ event = await self.snm.send_notice(
user_id=body["user_id"],
type=event_type,
state_key=state_key,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 9720a3bab0..d5d124a0dc 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -14,8 +14,6 @@
# limitations under the License.
import re
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -59,24 +57,22 @@ class UserAdminServlet(RestServlet):
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Only local users can be admins of this homeserver")
- is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user)
+ is_admin = await self.handlers.admin_handler.get_user_server_admin(target_user)
is_admin = bool(is_admin)
return 200, {"admin": is_admin}
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
target_user = UserID.from_string(user_id)
@@ -93,7 +89,7 @@ class UserAdminServlet(RestServlet):
if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.")
- yield self.handlers.admin_handler.set_user_server_admin(
+ await self.handlers.admin_handler.set_user_server_admin(
target_user, set_admin_to
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 25a1b67092..8414af08cb 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -29,6 +29,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -376,6 +377,7 @@ class CasTicketServlet(RestServlet):
super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
+ self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client()
@@ -399,6 +401,7 @@ class CasTicketServlet(RestServlet):
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
+ displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
@@ -413,7 +416,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url
+ user, request, client_redirect_url, displayname
)
def parse_cas_response(self, cas_response_body):
@@ -507,6 +510,19 @@ class SSOAuthHandler(object):
localpart=localpart, default_display_name=user_display_name
)
+ self.complete_sso_login(registered_user_id, request, client_redirect_url)
+
+ def complete_sso_login(
+ self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+ ):
+ """Having figured out a mxid for this user, complete the HTTP request
+
+ Args:
+ registered_user_id:
+ request:
+ client_redirect_url:
+ """
+
login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
)
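[Editor's note: from here the flow redirects the client back with the short-term token; schematically, with the query-parameter name assumed from the m.login.token flow.]

```python
# Schematic continuation (parameter name "loginToken" is an assumption):
login_token = self._macaroon_gen.generate_short_term_login_token(registered_user_id)
sep = "&" if "?" in client_redirect_url else "?"
redirect_url = client_redirect_url + sep + "loginToken=" + login_token
# The client then exchanges login_token via the m.login.token login type.
```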
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 3582259026..9c1d41421c 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -39,6 +39,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
@@ -81,6 +82,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
)
def on_PUT(self, request, txn_id):
+ set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
@defer.inlineCallbacks
@@ -181,6 +183,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester = yield self.auth.get_user_by_req(request)
+ if txn_id:
+ set_tag("txn_id", txn_id)
+
content = parse_json_object_from_request(request)
event_dict = {
@@ -209,6 +214,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
ret = {}
if event:
+ set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
return 200, ret
@@ -244,12 +250,15 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id
)
+ set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
)
@@ -310,6 +319,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
@@ -350,6 +361,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
@@ -387,6 +402,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
@@ -655,6 +674,8 @@ class RoomForgetRestServlet(TransactionRestServlet):
return 200, {}
def on_PUT(self, request, room_id, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
)
@@ -701,6 +722,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
+ content.get("id_access_token"),
)
return 200, {}
@@ -737,6 +759,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return True
def on_PUT(self, request, room_id, membership_action, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
)
@@ -770,9 +794,12 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id,
)
+ set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
)
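Every on_PUT handler above tags the transaction ID and then delegates to HttpTransactionCache.fetch_or_execute_request, which is what makes a retried PUT with the same txn_id idempotent. Below is a toy sketch of that pattern with a simplified cache and a stand-in send_event function; Synapse's real cache additionally scopes keys to the requester and expires old entries.

class TxnCache:
    def __init__(self):
        self._responses = {}

    def fetch_or_execute(self, txn_key, fn, *args):
        # First call with a given key executes fn; retries replay the result.
        if txn_key not in self._responses:
            self._responses[txn_key] = fn(*args)
        return self._responses[txn_key]

cache = TxnCache()
calls = []

def send_event(body):
    calls.append(body)
    return {"event_id": "$event%d" % len(calls)}

first = cache.fetch_or_execute(("@user:hs", "txn1"), send_event, {"msgtype": "m.text"})
retry = cache.fetch_or_execute(("@user:hs", "txn1"), send_event, {"msgtype": "m.text"})
assert first == retry and len(calls) == 1  # the retry did not send a second event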
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 785d01ea52..80cf7126a0 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -103,16 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- # Have the configured identity server handle the request
- if not self.hs.config.account_threepid_delegate_email:
- logger.warn(
- "No upstream email account_threepid_delegate configured on the server to "
- "handle this request"
- )
- raise SynapseError(
- 400, "Password reset by email is not supported on this homeserver"
- )
+ assert self.hs.config.account_threepid_delegate_email
+ # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
@@ -136,71 +129,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret
-class MsisdnPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
-
- def __init__(self, hs):
- super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
- self.hs = hs
- self.datastore = self.hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- body = parse_json_object_from_request(request)
-
- assert_params_in_dict(
- body, ["client_secret", "country", "phone_number", "send_attempt"]
- )
- client_secret = body["client_secret"]
- country = body["country"]
- phone_number = body["phone_number"]
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
-
- msisdn = phone_number_to_msisdn(country, phone_number)
-
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
- raise SynapseError(
- 403,
- "Account phone numbers are not authorized on this server",
- Codes.THREEPID_DENIED,
- )
-
- existing_user_id = yield self.datastore.get_user_id_by_threepid(
- "msisdn", msisdn
- )
-
- if existing_user_id is None:
- raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
-
- if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
- "No upstream msisdn account_threepid_delegate configured on the server to "
- "handle this request"
- )
- raise SynapseError(
- 400,
- "Password reset by phone number is not supported on this homeserver",
- )
-
- ret = yield self.identity_handler.requestMsisdnToken(
- self.hs.config.account_threepid_delegate_msisdn,
- country,
- phone_number,
- client_secret,
- send_attempt,
- next_link,
- )
-
- return 200, ret
-
-
class PasswordResetSubmitTokenServlet(RestServlet):
"""Handles 3PID validation token submission"""
PATTERNS = client_patterns(
- "/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True
+ "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
def __init__(self, hs):
@@ -214,6 +147,11 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ self.failure_email_template, = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_password_reset_template_failure_html],
+ )
@defer.inlineCallbacks
def on_GET(self, request, medium):
@@ -261,34 +199,12 @@ class PasswordResetSubmitTokenServlet(RestServlet):
request.setResponseCode(e.code)
# Show a failure page with a reason
- html_template, = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_password_reset_template_failure_html],
- )
-
template_vars = {"failure_reason": e.msg}
- html = html_template.render(**template_vars)
+ html = self.failure_email_template.render(**template_vars)
request.write(html.encode("utf-8"))
finish_request(request)
- @defer.inlineCallbacks
- def on_POST(self, request, medium):
- if medium != "email":
- raise SynapseError(
- 400, "This medium is currently not supported for password resets"
- )
-
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["sid", "client_secret", "token"])
-
- valid, _ = yield self.store.validate_threepid_session(
- body["sid"], body["client_secret"], body["token"], self.clock.time_msec()
- )
- response_code = 200 if valid else 400
-
- return response_code, {"success": valid}
-
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
@@ -325,9 +241,7 @@ class PasswordRestServlet(RestServlet):
else:
requester = None
result, params, _ = yield self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
- body,
- self.hs.get_ip_from_request(request),
+ [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
)
if LoginType.EMAIL_IDENTITY in result:
@@ -416,13 +330,35 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.store = self.hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ template_html, template_text = load_jinja2_templates(
+ self.config.email_template_dir,
+ [
+ self.config.email_add_threepid_template_html,
+ self.config.email_add_threepid_template_text,
+ ],
+ public_baseurl=self.config.public_baseurl,
+ )
+ self.mailer = Mailer(
+ hs=self.hs,
+ app_name=self.config.email_app_name,
+ template_html=template_html,
+ template_text=template_text,
+ )
+
@defer.inlineCallbacks
def on_POST(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warn(
+ "Adding emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Adding an email to your account is disabled on this server"
+ )
+
body = parse_json_object_from_request(request)
- assert_params_in_dict(
- body, ["id_server", "client_secret", "email", "send_attempt"]
- )
- id_server = "https://" + body["id_server"] # Assume https
+ assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
client_secret = body["client_secret"]
email = body["email"]
send_attempt = body["send_attempt"]
@@ -442,9 +378,30 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if existing_user_id is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- ret = yield self.identity_handler.requestEmailToken(
- id_server, email, client_secret, send_attempt, next_link
- )
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = yield self.identity_handler.requestEmailToken(
+ self.hs.config.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send threepid validation emails from Synapse
+ sid = yield self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_add_threepid_mail,
+ next_link,
+ )
+
+ # Wrap the session id in a JSON object
+ ret = {"sid": sid}
+
return 200, ret
@@ -461,10 +418,8 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
- body,
- ["id_server", "client_secret", "country", "phone_number", "send_attempt"],
+ body, ["client_secret", "country", "phone_number", "send_attempt"]
)
- id_server = "https://" + body["id_server"] # Assume https
client_secret = body["client_secret"]
country = body["country"]
phone_number = body["phone_number"]
@@ -485,12 +440,146 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if existing_user_id is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+ if not self.hs.config.account_threepid_delegate_msisdn:
+ logger.warn(
+ "No upstream msisdn account_threepid_delegate configured on the server to "
+ "handle this request"
+ )
+ raise SynapseError(
+ 400,
+ "Adding phone numbers to user account is not supported by this homeserver",
+ )
+
ret = yield self.identity_handler.requestMsisdnToken(
- id_server, country, phone_number, client_secret, send_attempt, next_link
+ self.hs.config.account_threepid_delegate_msisdn,
+ country,
+ phone_number,
+ client_secret,
+ send_attempt,
+ next_link,
)
+
return 200, ret
+class AddThreepidEmailSubmitTokenServlet(RestServlet):
+ """Handles 3PID validation token submission for adding an email to a user's account"""
+
+ PATTERNS = client_patterns(
+ "/add_threepid/email/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ self.failure_email_template, = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_add_threepid_template_failure_html],
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warn(
+ "Adding emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Adding an email to your account is disabled on this server"
+ )
+ elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating threepids. Use an identity server "
+ "instead.",
+ )
+
+ sid = parse_string(request, "sid", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ token = parse_string(request, "token", required=True)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = yield self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warn(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
+
+ # Otherwise show the success template
+ html = self.config.email_add_threepid_template_success_html_content
+ request.setResponseCode(200)
+ except ThreepidValidationError as e:
+ request.setResponseCode(e.code)
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html = self.failure_email_template.render(**template_vars)
+
+ request.write(html.encode("utf-8"))
+ finish_request(request)
+
+
+class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
+ """Handles 3PID validation token submission for adding a phone number to a user's
+ account
+ """
+
+ PATTERNS = client_patterns(
+ "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ if not self.config.account_threepid_delegate_msisdn:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating phone numbers. Use an identity server "
+ "instead.",
+ )
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["client_secret", "sid", "token"])
+
+ # Proxy submit_token request to msisdn threepid delegate
+ response = yield self.identity_handler.proxy_msisdn_submit_token(
+ self.config.account_threepid_delegate_msisdn,
+ body["client_secret"],
+ body["sid"],
+ body["token"],
+ )
+ return 200, response
+
+
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
@@ -512,6 +601,8 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
@@ -519,33 +610,96 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(
400, "Missing param three_pid_creds", Codes.MISSING_PARAM
)
+ assert_params_in_dict(threepid_creds, ["client_secret", "sid"])
+
+ client_secret = threepid_creds["client_secret"]
+ sid = threepid_creds["sid"]
+
+ validation_session = yield self.identity_handler.validate_threepid_session(
+ client_secret, sid
+ )
+ if validation_session:
+ yield self.auth_handler.add_threepid(
+ user_id,
+ validation_session["medium"],
+ validation_session["address"],
+ validation_session["validated_at"],
+ )
+ return 200, {}
+
+ raise SynapseError(
+ 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
+ )
+
+class ThreepidAddRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidAddRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
- threepid = yield self.identity_handler.threepid_from_creds(threepid_creds)
-
- if not threepid:
- raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
+ assert_params_in_dict(body, ["client_secret", "sid"])
+ client_secret = body["client_secret"]
+ sid = body["sid"]
- for reqd in ["medium", "address", "validated_at"]:
- if reqd not in threepid:
- logger.warn("Couldn't add 3pid: invalid response from ID server")
- raise SynapseError(500, "Invalid response from ID Server")
+ validation_session = yield self.identity_handler.validate_threepid_session(
+ client_secret, sid
+ )
+ if validation_session:
+ yield self.auth_handler.add_threepid(
+ user_id,
+ validation_session["medium"],
+ validation_session["address"],
+ validation_session["validated_at"],
+ )
+ return 200, {}
- yield self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ raise SynapseError(
+ 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
- if "bind" in body and body["bind"]:
- logger.debug("Binding threepid %s to %s", threepid, user_id)
- yield self.identity_handler.bind_threepid(threepid_creds, user_id)
+
+class ThreepidBindRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/bind$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidBindRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+ id_server = body["id_server"]
+ sid = body["sid"]
+ client_secret = body["client_secret"]
+ id_access_token = body.get("id_access_token") # optional
+
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ yield self.identity_handler.bind_threepid(
+ client_secret, sid, user_id, id_server, id_access_token
+ )
return 200, {}
class ThreepidUnbindRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/3pid/unbind$")
+ PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True)
def __init__(self, hs):
super(ThreepidUnbindRestServlet, self).__init__()
@@ -627,13 +781,16 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
+ AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
+ AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
+ ThreepidAddRestServlet(hs).register(http_server)
+ ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index c6ddf24c8d..17a8bc7366 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@@ -52,13 +52,15 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
- filter = yield self.filtering.get_user_filter(
+ filter_collection = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
)
+ except StoreError as e:
+ if e.code != 404:
+ raise
+ raise NotFoundError("No such filter")
- return 200, filter.get_filter_json()
- except (KeyError, StoreError):
- raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
+ return 200, filter_collection.get_filter_json()
class CreateFilterRestServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5c7a5f3579..4f24a124a6 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -16,6 +16,7 @@
import hmac
import logging
+from typing import List, Union
from six import string_types
@@ -31,9 +32,14 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.config import ConfigError
+from synapse.config.captcha import CaptchaConfig
+from synapse.config.consent_config import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
+from synapse.handlers.auth import AuthHandler
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -131,15 +137,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- if not self.hs.config.account_threepid_delegate_email:
- logger.warn(
- "No upstream email account_threepid_delegate configured on the server to "
- "handle this request"
- )
- raise SynapseError(
- 400, "Registration by email is not supported on this homeserver"
- )
+ assert self.hs.config.account_threepid_delegate_email
+ # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
@@ -246,6 +246,12 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ self.failure_email_template, = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_registration_template_failure_html],
+ )
+
@defer.inlineCallbacks
def on_GET(self, request, medium):
if medium != "email":
@@ -289,17 +301,11 @@ class RegistrationSubmitTokenServlet(RestServlet):
request.setResponseCode(200)
except ThreepidValidationError as e:
- # Show a failure page with a reason
request.setResponseCode(e.code)
# Show a failure page with a reason
- html_template, = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
- )
-
template_vars = {"failure_reason": e.msg}
- html = html_template.render(**template_vars)
+ html = self.failure_email_template.render(**template_vars)
request.write(html.encode("utf-8"))
finish_request(request)
@@ -334,6 +340,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
ip = self.hs.get_ip_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred:
yield wait_deferred
@@ -366,6 +377,10 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.clock = hs.get_clock()
+ self._registration_flows = _calculate_registration_flows(
+ hs.config, self.auth_handler
+ )
+
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
@@ -486,69 +501,8 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
- # FIXME: need a better error than "no auth flow found" for scenarios
- # where we required 3PID for registration but the user didn't give one
- require_email = "email" in self.hs.config.registrations_require_3pid
- require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid
-
- show_msisdn = True
- if self.hs.config.disable_msisdn_registration:
- show_msisdn = False
- require_msisdn = False
-
- flows = []
- if self.hs.config.enable_registration_captcha:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- # Also add a dummy flow here, otherwise if a client completes
- # recaptcha first we'll assume they were going for this flow
- # and complete the request, when they could have been trying to
- # complete one of the flows with email/msisdn auth.
- flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email:
- flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend(
- [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]
- )
- else:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- flows.extend([[LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email or require_msisdn:
- flows.extend([[LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]])
-
- # Append m.login.terms to all flows if we're requiring consent
- if self.hs.config.user_consent_at_registration:
- new_flows = []
- for flow in flows:
- inserted = False
- # m.login.terms should go near the end but before msisdn or email auth
- for i, stage in enumerate(flow):
- if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
- flow.insert(i, LoginType.TERMS)
- inserted = True
- break
- if not inserted:
- flow.append(LoginType.TERMS)
- flows.extend(new_flows)
-
auth_result, params, session_id = yield self.auth_handler.check_auth(
- flows, body, self.hs.get_ip_from_request(request)
+ self._registration_flows, body, self.hs.get_ip_from_request(request)
)
# Check that we're not trying to register a denied 3pid.
@@ -711,6 +665,83 @@ class RegisterRestServlet(RestServlet):
)
+def _calculate_registration_flows(
+ # technically `config` has to provide *all* of these interfaces, not just one
+ config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
+ auth_handler: AuthHandler,
+) -> List[List[str]]:
+ """Get a suitable flows list for registration
+
+ Args:
+ config: server configuration
+ auth_handler: authorization handler
+
+ Returns: a list of supported flows
+ """
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = "email" in config.registrations_require_3pid
+ require_msisdn = "msisdn" in config.registrations_require_3pid
+
+ show_msisdn = True
+ show_email = True
+
+ if config.disable_msisdn_registration:
+ show_msisdn = False
+ require_msisdn = False
+
+ enabled_auth_types = auth_handler.get_enabled_auth_types()
+ if LoginType.EMAIL_IDENTITY not in enabled_auth_types:
+ show_email = False
+ if require_email:
+ raise ConfigError(
+ "Configuration requires email address at registration, but email "
+ "validation is not configured"
+ )
+
+ if LoginType.MSISDN not in enabled_auth_types:
+ show_msisdn = False
+ if require_msisdn:
+ raise ConfigError(
+ "Configuration requires msisdn at registration, but msisdn "
+ "validation is not configured"
+ )
+
+ flows = []
+
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ # Add a dummy step here, otherwise if a client completes
+ # recaptcha first we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.append([LoginType.DUMMY])
+
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if show_email and not require_msisdn:
+ flows.append([LoginType.EMAIL_IDENTITY])
+
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if show_msisdn and not require_email:
+ flows.append([LoginType.MSISDN])
+
+ if show_email and show_msisdn:
+ # always let users provide both MSISDN & email
+ flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
+
+ # Prepend m.login.terms to all flows if we're requiring consent
+ if config.user_consent_at_registration:
+ for flow in flows:
+ flow.insert(0, LoginType.TERMS)
+
+ # Prepend recaptcha to all flows if we're requiring captcha
+ if config.enable_registration_captcha:
+ for flow in flows:
+ flow.insert(0, LoginType.RECAPTCHA)
+
+ return flows
+
+
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
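To make the flow calculation concrete, here is a condensed, self-contained rendering of _calculate_registration_flows; the stubbed `enabled` set stands in for auth_handler.get_enabled_auth_types(), and the ConfigError and disable_msisdn_registration handling are elided.

EMAIL = "m.login.email.identity"
MSISDN = "m.login.msisdn"
DUMMY = "m.login.dummy"
RECAPTCHA = "m.login.recaptcha"
TERMS = "m.login.terms"

def calculate_flows(enabled, require_email=False, require_msisdn=False,
                    captcha=False, consent=False):
    flows = []
    # 3PID-less registration is only offered when no 3PID is required.
    if not require_email and not require_msisdn:
        flows.append([DUMMY])
    if EMAIL in enabled and not require_msisdn:
        flows.append([EMAIL])
    if MSISDN in enabled and not require_email:
        flows.append([MSISDN])
    if EMAIL in enabled and MSISDN in enabled:
        flows.append([MSISDN, EMAIL])
    # Terms, then recaptcha, are prepended to every flow when enabled,
    # so recaptcha ends up as the first stage.
    if consent:
        for flow in flows:
            flow.insert(0, TERMS)
    if captcha:
        for flow in flows:
            flow.insert(0, RECAPTCHA)
    return flows

assert calculate_flows({EMAIL}, captcha=True) == [
    [RECAPTCHA, DUMMY],
    [RECAPTCHA, EMAIL],
]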
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index df4f44cd36..d596786430 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -375,7 +375,7 @@ class RoomKeysVersionServlet(RestServlet):
"ed25519:something": "hijklmnop"
}
},
- "version": "42"
+ "version": "12345"
}
HTTP/1.1 200 OK
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index c98c5a3802..a883c8adda 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import PresenceState
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
@@ -119,25 +119,32 @@ class SyncRestServlet(RestServlet):
request_key = (user, timeout, since, filter_id, full_state, device_id)
- if filter_id:
- if filter_id.startswith("{"):
- try:
- filter_object = json.loads(filter_id)
- set_timeline_upper_limit(
- filter_object, self.hs.config.filter_timeline_limit
- )
- except Exception:
- raise SynapseError(400, "Invalid filter JSON")
- self.filtering.check_valid_filter(filter_object)
- filter = FilterCollection(filter_object)
- else:
- filter = yield self.filtering.get_user_filter(user.localpart, filter_id)
+ if filter_id is None:
+ filter_collection = DEFAULT_FILTER_COLLECTION
+ elif filter_id.startswith("{"):
+ try:
+ filter_object = json.loads(filter_id)
+ set_timeline_upper_limit(
+ filter_object, self.hs.config.filter_timeline_limit
+ )
+ except Exception:
+ raise SynapseError(400, "Invalid filter JSON")
+ self.filtering.check_valid_filter(filter_object)
+ filter_collection = FilterCollection(filter_object)
else:
- filter = DEFAULT_FILTER_COLLECTION
+ try:
+ filter_collection = yield self.filtering.get_user_filter(
+ user.localpart, filter_id
+ )
+ except StoreError as err:
+ if err.code != 404:
+ raise
+ # fix up the description and errcode to be more useful
+ raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM)
sync_config = SyncConfig(
user=user,
- filter_collection=filter,
+ filter_collection=filter_collection,
is_guest=requester.is_guest,
request_key=request_key,
device_id=device_id,
@@ -171,7 +178,7 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec()
response_content = yield self.encode_response(
- time_now, sync_result, requester.access_token_id, filter
+ time_now, sync_result, requester.access_token_id, filter_collection
)
return 200, response_content
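As a summary of the branches above, the `filter` query parameter now takes exactly three forms; a small sketch with illustrative values only:

def classify_filter_param(filter_id):
    # Mirrors the three branches in SyncRestServlet above.
    if filter_id is None:
        return "default"      # DEFAULT_FILTER_COLLECTION
    if filter_id.startswith("{"):
        return "inline-json"  # parsed, validated and capped per request
    return "stored-id"        # looked up in the DB; unknown IDs now 400

assert classify_filter_param(None) == "default"
assert classify_filter_param('{"room": {"timeline": {"limit": 10}}}') == "inline-json"
assert classify_filter_param("5") == "stored-id"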
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0058b6b459..1044ae7b4e 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -48,7 +48,24 @@ class VersionsRestServlet(RestServlet):
"r0.5.0",
],
# as per MSC1497:
- "unstable_features": {"m.lazy_load_members": True},
+ "unstable_features": {
+ "m.lazy_load_members": True,
+ # as per MSC2190, as amended by MSC2264
+ # to be removed in r0.6.0
+ "m.id_access_token": True,
+ # Advertise to clients that they need not include an `id_server`
+ # parameter during registration or password reset, as Synapse now decides
+ # itself which identity server to use (or none at all).
+ #
+ # This is also used by a client when they wish to bind a 3PID to their
+ # account, but not bind it to an identity server, the endpoint for which
+ # also requires `id_server`. If the homeserver is handling 3PID
+ # verification itself, there is no need to ask the user for `id_server` to
+ # be supplied.
+ "m.require_identity_server": False,
+ # as per MSC2290
+ "m.separate_add_and_bind": True,
+ },
},
)
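Taken together with the new /account/3pid/add and /account/3pid/bind servlets above, m.separate_add_and_bind advertises the two-step client flow sketched below. This is a hedged sketch using the requests library; the homeserver URL, tokens and session values are placeholders and error handling is omitted.

import requests

HS = "https://example.com/_matrix/client/unstable"
headers = {"Authorization": "Bearer <access_token>"}

# 1. Add the validated 3PID to the local account only:
requests.post(
    HS + "/account/3pid/add",
    json={"client_secret": "<secret>", "sid": "<session_id>"},
    headers=headers,
)

# 2. Optionally bind it to an identity server as a separate step:
requests.post(
    HS + "/account/3pid/bind",
    json={
        "id_server": "id.example.com",
        "id_access_token": "<is_token>",
        "client_secret": "<secret>",
        "sid": "<session_id>",
    },
    headers=headers,
)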
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 5fefee4dde..65bbf00073 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -195,7 +195,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
- logger.debug("Responding to media request with responder %s")
+ logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 7a56cd4b6c..0c68c3aad5 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -270,7 +270,7 @@ class PreviewUrlResource(DirectServeResource):
logger.debug("Calculated OG for %s as %s" % (url, og))
- jsonog = json.dumps(og).encode("utf8")
+ jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@@ -283,7 +283,7 @@ class PreviewUrlResource(DirectServeResource):
media_info["created_ts"],
)
- return jsonog
+ return jsonog.encode("utf8")
@defer.inlineCallbacks
def _download_url(self, url, user):
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index c995d7e043..8cf415e29d 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -82,13 +82,21 @@ class Thumbnailer(object):
else:
return (max_height * self.width) // self.height, max_height
+ def _resize(self, width, height):
+ # 1-bit or 8-bit color palette images need converting to RGB,
+ # otherwise they will be scaled using nearest neighbour, which
+ # looks awful.
+ if self.image.mode in ["1", "P"]:
+ self.image = self.image.convert("RGB")
+ return self.image.resize((width, height), Image.ANTIALIAS)
+
def scale(self, width, height, output_type):
"""Rescales the image to the given dimensions.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
- scaled = self.image.resize((width, height), Image.ANTIALIAS)
+ scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
@@ -107,13 +115,13 @@ class Thumbnailer(object):
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
- scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS)
+ scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
- scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS)
+ scaled_image = self._resize(scaled_width, height)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
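A minimal Pillow sketch of the conversion _resize now performs; the mode check mirrors the patch, while the resample lookup is a portability shim (newer Pillow releases removed the ANTIALIAS alias in favour of LANCZOS).

from PIL import Image

img = Image.new("P", (64, 64))  # a palette-mode source image
if img.mode in ["1", "P"]:
    # Convert so the scaler can interpolate colour values instead of
    # falling back to nearest-neighbour palette lookups.
    img = img.convert("RGB")

resample = getattr(Image, "ANTIALIAS", getattr(Image, "LANCZOS", None))
thumb = img.resize((32, 32), resample)
assert thumb.mode == "RGB" and thumb.size == (32, 32)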
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 5d76bbdf68..83d005812d 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -17,7 +17,7 @@ import logging
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
from synapse.http.server import (
DirectServeResource,
respond_with_json,
@@ -56,7 +56,11 @@ class UploadResource(DirectServeResource):
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
- raise SynapseError(msg="Upload request body is too large", code=413)
+ raise SynapseError(
+ msg="Upload request body is too large",
+ code=413,
+ errcode=Codes.TOO_LARGE,
+ )
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
diff --git a/synapse/server.py b/synapse/server.py
index 9e28dba2b1..1fcc7375d3 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -221,6 +221,7 @@ class HomeServer(object):
self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
+ self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
self.datastore = None
@@ -279,6 +280,9 @@ class HomeServer(object):
def get_registration_ratelimiter(self):
return self.registration_ratelimiter
+ def get_admin_redaction_ratelimiter(self):
+ return self.admin_redaction_ratelimiter
+
def build_federation_client(self):
return FederationClient(self)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2b0f4c79ee..dc9f5a9008 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -33,7 +33,7 @@ from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.metrics import Measure
+from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@@ -191,11 +191,22 @@ class StateHandler(object):
return joined_users
@defer.inlineCallbacks
- def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
- if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
- entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ def get_current_hosts_in_room(self, room_id):
+ event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+
+ @defer.inlineCallbacks
+ def get_hosts_in_room_at_events(self, room_id, event_ids):
+ """Get the hosts that were in a room at the given event ids
+
+ Args:
+ room_id (str): the room to check
+ event_ids (list[str]): the event IDs at which to check the hosts
+
+ Returns:
+ Deferred[list[str]]: the hosts in the room at the given events
+ """
+ entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
return joined_hosts
@@ -344,6 +355,7 @@ class StateHandler(object):
return context
+ @measure_func()
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index e02663f50e..276c271bbe 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -62,7 +62,7 @@ var show_login = function() {
$("#sso_flow").show();
}
- if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas) {
+ if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) {
$("#no_login_types").show();
}
};
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abe16334ec..f5906fcd54 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -20,6 +20,7 @@ import random
import sys
import threading
import time
+from typing import Iterable, Tuple
from six import PY2, iteritems, iterkeys, itervalues
from six.moves import builtins, intern, range
@@ -30,7 +31,7 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
@@ -550,8 +551,9 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ result = yield make_deferred_yieldable(
+ self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ )
return result
@@ -1162,19 +1164,18 @@ class SQLBaseStore(object):
if not iterable:
return []
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-
- clauses = []
- values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join(clauses),
+ )
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
@@ -1323,10 +1324,8 @@ class SQLBaseStore(object):
sql = "DELETE FROM %s" % table
- clauses = []
- values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
@@ -1693,3 +1692,30 @@ def db_to_json(db_content):
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
+
+
+def make_in_list_sql_clause(
+ database_engine, column: str, iterable: Iterable
+) -> Tuple[str, Iterable]:
+ """Returns an SQL clause that checks the given column is in the iterable.
+
+ On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+ it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+ using the `ANY` form on Postgres means that it treats queries with
+ different-length iterables as the same query, keeping the query stats meaningful.
+
+ Args:
+ database_engine
+ column: Name of the column
+ iterable: The values to check the column against.
+
+ Returns:
+ A tuple of SQL query and the args
+ """
+
+ if database_engine.supports_using_any_list:
+ # This should hopefully be faster, but also makes postgres query
+ # stats easier to understand.
+ return "%s = ANY(?)" % (column,), [list(iterable)]
+ else:
+ return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index e5f0668f09..80b57a948c 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -140,7 +140,7 @@ class BackgroundUpdateStore(SQLBaseStore):
"background_updates",
keyvalues=None,
retcol="1",
- desc="check_background_updates",
+ desc="has_completed_background_updates",
)
if not updates:
self._all_done = True
@@ -148,6 +148,26 @@ class BackgroundUpdateStore(SQLBaseStore):
return False
+ async def has_completed_background_update(self, update_name) -> bool:
+ """Check if the given background update has finished running.
+ """
+
+ if self._all_done:
+ return True
+
+ if update_name in self._background_update_queue:
+ return False
+
+ update_exists = await self._simple_select_one_onecol(
+ "background_updates",
+ keyvalues={"update_name": update_name},
+ retcol="1",
+ desc="has_completed_background_update",
+ allow_none=True,
+ )
+
+ return not update_exists
+
@defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
@@ -218,7 +238,7 @@ class BackgroundUpdateStore(SQLBaseStore):
duration_ms = time_stop - time_start
logger.info(
- "Updating %r. Updated %r items in %rms."
+ "Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name,
items_updated,
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 6db8c54077..067820a5da 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -19,7 +19,7 @@ from six import iteritems
from twisted.internet import defer
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR
from . import background_updates
@@ -33,14 +33,9 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
-class ClientIpStore(background_updates.BackgroundUpdateStore):
+class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
-
- self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
- )
-
- super(ClientIpStore, self).__init__(db_conn, hs)
+ super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"user_ips_device_index",
@@ -85,14 +80,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
+ # Update the last seen info in devices.
+ self.register_background_update_handler(
+ "devices_last_seen", self._devices_last_seen_update
)
@defer.inlineCallbacks
@@ -294,6 +284,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return batch_size
@defer.inlineCallbacks
+ def _devices_last_seen_update(self, progress, batch_size):
+ """Background update to insert last seen info into devices table
+ """
+
+ last_user_id = progress.get("last_user_id", "")
+ last_device_id = progress.get("last_device_id", "")
+
+ def _devices_last_seen_update_txn(txn):
+ # This consists of two queries:
+ #
+ # 1. The sub-query searches for the next N devices and joins
+ # against user_ips to find the max last_seen associated with
+ # that device.
+ # 2. The outer query then joins again against user_ips on
+ # user/device/last_seen. This *should* hopefully only
+ # return one row, but if it does return more than one then
+ # we'll just end up updating the same device row multiple
+ # times, which is fine.
+
+ if self.database_engine.supports_tuple_comparison:
+ where_clause = "(user_id, device_id) > (?, ?)"
+ where_args = [last_user_id, last_device_id]
+ else:
+ # We explicitly do a `user_id >= ? AND (...)` here to ensure
+ # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
+ # makes it hard for query optimiser to tell that it can use the
+ # index on user_id
+ where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
+ where_args = [last_user_id, last_user_id, last_device_id]
+
+ sql = """
+ SELECT
+ last_seen, ip, user_agent, user_id, device_id
+ FROM (
+ SELECT
+ user_id, device_id, MAX(u.last_seen) AS last_seen
+ FROM devices
+ INNER JOIN user_ips AS u USING (user_id, device_id)
+ WHERE %(where_clause)s
+ GROUP BY user_id, device_id
+ ORDER BY user_id ASC, device_id ASC
+ LIMIT ?
+ ) c
+ INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
+ """ % {
+ "where_clause": where_clause
+ }
+ txn.execute(sql, where_args + [batch_size])
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ sql = """
+ UPDATE devices
+ SET last_seen = ?, ip = ?, user_agent = ?
+ WHERE user_id = ? AND device_id = ?
+ """
+ txn.execute_batch(sql, rows)
+
+ _, _, _, user_id, device_id = rows[-1]
+ self._background_update_progress_txn(
+ txn,
+ "devices_last_seen",
+ {"last_user_id": user_id, "last_device_id": device_id},
+ )
+
+ return len(rows)
+
+ updated = yield self.runInteraction(
+ "_devices_last_seen_update", _devices_last_seen_update_txn
+ )
+
+ if not updated:
+ yield self._end_background_update("devices_last_seen")
+
+ return updated
+
+
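A sanity check of the fallback predicate used in _devices_last_seen_update above: the expanded form is logically equivalent to the tuple comparison while keeping a leading user_id >= ? term that the planner can service from the user_id index.

# Equivalence check in pure Python over a few sample cursor positions:
pairs = [("@a:hs", "dev0"), ("@a:hs", "dev1"), ("@b:hs", "dev9")]
for user_id, device_id in pairs:
    for last_u, last_d in pairs:
        tuple_form = (user_id, device_id) > (last_u, last_d)
        expanded = user_id >= last_u and (user_id > last_u or device_id > last_d)
        assert tuple_form == expanded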
+class ClientIpStore(ClientIpBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ )
+
+ super(ClientIpStore, self).__init__(db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
+
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
+ if self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+ @defer.inlineCallbacks
def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
@@ -314,20 +408,19 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
+ @wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
return
- def update():
- to_update = self._batch_row_update
- self._batch_row_update = {}
- return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
- return run_as_background_process("update_client_ips", update)
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or (
@@ -354,6 +447,21 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
},
lock=False,
)
+
+ # Technically an access token might not be associated with
+ # a device so we need to check.
+ if device_id:
+ self._simple_upsert_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ "ip": ip,
+ },
+ lock=False,
+ )
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
@@ -372,19 +480,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
keys giving the column names
"""
- res = yield self.runInteraction(
- "get_last_client_ip_by_device",
- self._get_last_client_ip_by_device_txn,
- user_id,
- device_id,
- retcols=(
- "user_id",
- "access_token",
- "ip",
- "user_agent",
- "device_id",
- "last_seen",
- ),
+ keyvalues = {"user_id": user_id}
+ if device_id is not None:
+ keyvalues["device_id"] = device_id
+
+ res = yield self._simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
@@ -403,42 +506,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
}
return ret
- @classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
- where_clauses = []
- bindings = []
- if device_id is None:
- where_clauses.append("user_id = ?")
- bindings.extend((user_id,))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
-
- if not where_clauses:
- return []
-
- inner_select = (
- "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
- "WHERE %(where)s "
- "GROUP BY user_id, device_id"
- ) % {"where": " OR ".join(where_clauses)}
-
- sql = (
- "SELECT %(retcols)s FROM user_ips "
- "JOIN (%(inner_select)s) ips ON"
- " user_ips.last_seen = ips.mls AND"
- " user_ips.user_id = ips.user_id AND"
- " (user_ips.device_id = ips.device_id OR"
- " (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
- " )"
- ) % {
- "retcols": ",".join("user_ips." + c for c in retcols),
- "inner_select": inner_select,
- }
-
- txn.execute(sql, bindings)
- return cls.cursor_to_dict(txn)
-
@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
user_id = user.to_string()
@@ -470,3 +537,45 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
}
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.has_completed_background_update("devices_last_seen"):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try to delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
+ )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
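The bounded-delete trick in _prune_old_user_ips can be exercised against an in-memory SQLite table standing in for user_ips: each pass removes at most 5000 rows older than the cutoff, so a large backlog is pruned incrementally rather than in one huge transaction.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen BIGINT)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?)",
    [("@u%d:hs" % i, i) for i in range(20000)],
)

cutoff = 15000  # everything with last_seen <= cutoff is eligible for pruning
sql = """
    DELETE FROM user_ips
    WHERE last_seen <= (
        SELECT COALESCE(MAX(last_seen), -1)
        FROM (
            SELECT last_seen FROM user_ips
            WHERE last_seen <= ?
            ORDER BY last_seen ASC
            LIMIT 5000
        ) AS u
    )
"""
conn.execute(sql, (cutoff,))
remaining = conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0]
assert remaining == 15000  # one pass removed exactly 5000 of the old rows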
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 6b7458304e..f04aad0743 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -20,7 +20,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
@@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
-class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
+class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
- super(DeviceInboxStore, self).__init__(db_conn, hs)
+ super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_inbox_stream_index",
@@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
+ @defer.inlineCallbacks
+ def _background_drop_index_device_inbox(self, progress, batch_size):
+ def reindex_txn(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+ txn.close()
+
+ yield self.runWithConnection(reindex_txn)
+
+ yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+ return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+ DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+ def __init__(self, db_conn, hs):
+ super(DeviceInboxStore, self).__init__(db_conn, hs)
+
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
@@ -358,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
else:
if not devices:
continue
- sql = (
- "SELECT device_id FROM devices"
- " WHERE user_id = ? AND device_id IN ("
- + ",".join("?" * len(devices))
- + ")"
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "device_id", devices
)
+ sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
+
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
- txn.execute(sql, [user_id] + devices)
+ txn.execute(sql, [user_id] + list(args))
for row in txn:
# Only insert into the local inbox if the device exists on
# this server
@@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
-
- @defer.inlineCallbacks
- def _background_drop_index_device_inbox(self, progress, batch_size):
- def reindex_txn(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
- txn.close()
-
- yield self.runWithConnection(reindex_txn)
-
- yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
- return 1
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 2db3ef4004..f7a3542348 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -30,7 +30,12 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import Cache, SQLBaseStore, db_to_json
+from synapse.storage._base import (
+ Cache,
+ SQLBaseStore,
+ db_to_json,
+ make_in_list_sql_clause,
+)
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -487,11 +492,14 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
- AND user_id IN (%s)
+ AND
"""
for chunk in batch_iter(to_check, 100):
- txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk)
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql + clause, (from_key,) + tuple(args))
changes.update(user_id for user_id, in txn)
return changes
@@ -573,17 +581,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
-class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
+class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
- super(DeviceStore, self).__init__(db_conn, hs)
-
- # Map of (user_id, device_id) -> bool. If there is an entry that implies
- # the device exists.
- self.device_id_exists_cache = Cache(
- name="device_id_exists", keylen=2, max_entries=10000
- )
-
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+ super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_lists_stream_idx",
@@ -617,6 +617,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
@defer.inlineCallbacks
+ def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ def f(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
+ txn.close()
+
+ yield self.runWithConnection(f)
+ yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
+ return 1
+
+
+class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(DeviceStore, self).__init__(db_conn, hs)
+
+ # Map of (user_id, device_id) -> bool. If there is an entry that implies
+ # the device exists.
+ self.device_id_exists_cache = Cache(
+ name="device_id_exists", keylen=2, max_entries=10000
+ )
+
+ self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+
+ @defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
@@ -987,15 +1012,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"_prune_old_outbound_device_pokes",
_prune_txn,
)
-
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
- def f(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
- txn.close()
-
- yield self.runWithConnection(f)
- yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
- return 1
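
This split of a concrete store into a `*BackgroundUpdateStore` base class
plus a thin subclass repeats throughout the diff (device inbox, media
repository, registration), presumably so that tooling can instantiate the
background-update machinery without the full store. A minimal sketch of the
cooperative-inheritance shape (hypothetical names; Python's MRO runs each
__init__ exactly once):

    class Base:
        def __init__(self):
            print("base store")

    class BackgroundUpdates(Base):
        def __init__(self):
            super().__init__()
            print("register background updates")

    class Worker(Base):
        def __init__(self):
            super().__init__()
            print("worker reads")

    class Store(Worker, BackgroundUpdates):
        def __init__(self):
            super().__init__()
            print("full store")

    Store()  # base store, register background updates, worker reads, full store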
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 625f95234f..2effe08592 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -42,7 +42,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
- dict containing "key_json", "device_display_name".
+ key data. The key data will be a dict in the same format as the
+ DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
set_tag("query_list", query_list)
if not query_list:
@@ -56,17 +57,25 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_deleted_devices,
)
+ # Build the result structure, un-jsonify the results, and add the
+ # "unsigned" section
+ rv = {}
for user_id, device_keys in iteritems(results):
+ rv[user_id] = {}
for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = db_to_json(device_info.pop("key_json"))
- # add cross-signing signatures to the keys
+ r = db_to_json(device_info.pop("key_json"))
+ r["unsigned"] = {}
+ display_name = device_info["device_display_name"]
+ if display_name is not None:
+ r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info:
for sig_user_id, sigs in device_info["signatures"].items():
- device_info["keys"].setdefault("signatures", {}).setdefault(
+ r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
+ rv[user_id][device_id] = r
- return results
+ return rv
@trace
def _get_e2e_device_keys_txn(
@@ -378,7 +387,8 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# The "keys" property must only have one entry, which will be the public
# key, so we just grab the first value in there
pubkey = next(iter(key["keys"].values()))
- self._simple_insert(
+ self._simple_insert_txn(
+ txn,
"devices",
values={
"user_id": user_id,
@@ -386,12 +396,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"display_name": key_type + " signing key",
"hidden": True,
},
- desc="store_master_key_device",
)
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
- self._simple_insert(
+ self._simple_insert_txn(
+ txn,
"e2e_cross_signing_keys",
values={
"user_id": user_id,
@@ -399,7 +409,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"keydata": json.dumps(key),
"stream_id": stream_id,
},
- desc="store_master_key",
)
def set_e2e_cross_signing_key(self, user_id, key_type, key):
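
For illustration, the reshaped return value of the device-key lookup now
looks like the following (all values made up), matching the DeviceKeys
format from `POST /_matrix/client/r0/keys/query`, with the display name
moved under "unsigned":

    example = {
        "@alice:example.com": {
            "DEVICEID": {
                "user_id": "@alice:example.com",
                "device_id": "DEVICEID",
                "algorithms": ["m.olm.v1.curve25519-aes-sha2"],
                "keys": {"curve25519:DEVICEID": "..."},
                "signatures": {"@alice:example.com": {"ed25519:DEVICEID": "..."}},
                "unsigned": {"device_display_name": "Alice's phone"},
            }
        }
    }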
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 289b6bc281..b7c4eda338 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -22,6 +22,13 @@ class PostgresEngine(object):
def __init__(self, database_module, database_config):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
+
+ # Disable passing `bytes` to txn.execute, cf. #6186. If you
+ # actually do want to pass bytes, wrap them in `bytearray`.
+ def _disable_bytes_adapter(_):
+ raise Exception("Passing bytes to DB is disabled.")
+
+ self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
@@ -72,6 +79,19 @@ class PostgresEngine(object):
"""
return True
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ return True
+
+ @property
+ def supports_using_any_list(self):
+ """Do we support using `a = ANY(?)` and passing a list
+ """
+ return True
+
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
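
The adapter hook makes any query that passes a raw `bytes` parameter fail
loudly; call sites that genuinely store binary data wrap it in `bytearray`
instead (see the `filtering.py` and `pusher.py` hunks below). A sketch of
the psycopg2 mechanism, assuming a table `t` with a bytea column `b`:

    from psycopg2.extensions import register_adapter

    def _disable_bytes_adapter(_):
        raise Exception("Passing bytes to DB is disabled.")

    # Registering an adapter for `bytes` does not affect `bytearray`,
    # which psycopg2 continues to adapt to bytea as before.
    register_adapter(bytes, _disable_bytes_adapter)

    # cur.execute("INSERT INTO t (b) VALUES (%s)", (b"blob",))             # raises
    # cur.execute("INSERT INTO t (b) VALUES (%s)", (bytearray(b"blob"),))  # fine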
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index e9b9caa49a..ddad17dc5a 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -38,6 +38,20 @@ class Sqlite3Engine(object):
"""
return self.module.sqlite_version_info >= (3, 24, 0)
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
+ SQLite 3.15+.
+ """
+ return self.module.sqlite_version_info >= (3, 15, 0)
+
+ @property
+ def supports_using_any_list(self):
+ """Do we support using `a = ANY(?)` and passing a list
+ """
+ return False
+
def check_database(self, txn):
pass
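
`supports_tuple_comparison` gates SQL row-value comparisons such as
`(a, b) > (c, d)` (supported by Postgres, and by SQLite from 3.15). Where
unavailable, a caller would have to expand the comparison by hand; the two
forms are equivalent for NOT NULL columns. This is a sketch of the
trade-off, not a helper that exists in this diff:

    def tuple_gt_clause(engine):
        # Row-value form where the engine supports it...
        if engine.supports_tuple_comparison:
            return "(a, b) > (?, ?)"
        # ...otherwise the manual expansion (note: three placeholders,
        # with the first bound value repeated).
        return "(a > ? OR (a = ? AND b > ?))"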
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 4f500d893e..47cc10d32a 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
import random
@@ -24,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached
@@ -67,7 +68,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
+ base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -75,7 +76,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
front_list = list(front)
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks:
- txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", chunk
+ )
+ txn.execute(base_sql + clause, list(args))
new_front.update([r[0] for r in txn])
new_front -= results
@@ -190,12 +194,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
room_id,
)
- def get_rooms_with_many_extremities(self, min_count, limit):
+ def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
"""Get the top rooms with at least N extremities.
Args:
min_count (int): The minimum number of extremities
limit (int): The maximum number of rooms to return.
+ room_id_filter (iterable[str]): room_ids to exclude from the results
Returns:
Deferred[list]: At most `limit` room IDs that have at least
@@ -203,15 +208,25 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
def _get_rooms_with_many_extremities_txn(txn):
+ where_clause = "1=1"
+ if room_id_filter:
+ where_clause = "room_id NOT IN (%s)" % (
+ ",".join("?" for _ in room_id_filter),
+ )
+
sql = """
SELECT room_id FROM event_forward_extremities
+ WHERE %s
GROUP BY room_id
HAVING count(*) > ?
ORDER BY count(*) DESC
LIMIT ?
- """
+ """ % (
+ where_clause,
+ )
- txn.execute(sql, (min_count, limit))
+ query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
+ txn.execute(sql, query_args)
return [room_id for room_id, in txn]
return self.runInteraction(
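
A note on the shape of `_get_rooms_with_many_extremities_txn`: `WHERE 1=1`
is the conventional stand-in when `room_id_filter` is empty, since
`NOT IN ()` is a syntax error on both engines. The parameter order then has
to mirror the SQL, filter values first (illustrative values only):

    room_id_filter = ["!a:example.com", "!b:example.com"]
    min_count, limit = 5, 10
    query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
    # -> ["!a:example.com", "!b:example.com", 5, 10]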
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 1958afe1d7..ee49ef235d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -33,11 +33,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events.utils import prune_event_dict
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
+from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
@@ -262,6 +264,14 @@ class EventsStore(
hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+ def _censor_redactions():
+ return run_as_background_process(
+ "_censor_redactions", self._censor_redactions
+ )
+
+ if self.hs.config.redaction_retention_period is not None:
+ hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
@defer.inlineCallbacks
def _read_forward_extremities(self):
def fetch(txn):
@@ -632,14 +642,16 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
- prev_event_id IN (%s)
- AND NOT events.outlier
+ NOT events.outlier
AND rejections.event_id IS NULL
- """ % (
- ",".join("?" for _ in batch),
+ AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "prev_event_id", batch
)
- txn.execute(sql, batch)
+ txn.execute(sql + clause, args)
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
@@ -686,13 +698,15 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
- event_id IN (%s)
- AND NOT events.outlier
- """ % (
- ",".join("?" for _ in to_recursively_check),
+ NOT events.outlier
+ AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", to_recursively_check
)
- txn.execute(sql, to_recursively_check)
+ txn.execute(sql + clause, args)
to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn:
@@ -1380,6 +1394,18 @@ class EventsStore(
],
)
+ for event, _ in events_and_contexts:
+ if not event.internal_metadata.is_redacted():
+ # If we're persisting an unredacted event we ensure that any
+ # redactions referencing it are marked as requiring censoring.
+ self._simple_update_txn(
+ txn,
+ table="redactions",
+ keyvalues={"redacts": event.event_id},
+ updatevalues={"have_censored": False},
+ )
+
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
rejected
@@ -1522,10 +1548,14 @@ class EventsStore(
" FROM events as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(ev_map)),)
+ " WHERE "
+ )
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "e.event_id", list(ev_map)
+ )
- txn.execute(sql, list(ev_map))
+ txn.execute(sql + clause, args)
rows = self.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
@@ -1543,12 +1573,101 @@ class EventsStore(
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
- txn.execute(
- "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts),
+
+ self._simple_insert_txn(
+ txn,
+ table="redactions",
+ values={
+ "event_id": event.event_id,
+ "redacts": event.redacts,
+ "received_ts": self._clock.time_msec(),
+ },
)
@defer.inlineCallbacks
+ def _censor_redactions(self):
+ """Censors all redactions older than the configured period that haven't
+ been censored yet.
+
+ By "censor" we mean replacing the stored JSON in the event_json
+ table with that of the pruned (redacted) event.
+
+ Returns:
+ Deferred
+ """
+
+ if self.hs.config.redaction_retention_period is None:
+ return
+
+ before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+ # We fetch all redactions that:
+ # 1. point to an event we have,
+ # 2. have a received_ts from before the cut-off, and
+ # 3. haven't yet been censored.
+ #
+ # This is limited to 100 events to ensure that we don't try and do too
+ # much at once. We'll get called again so this should eventually catch
+ # up.
+ sql = """
+ SELECT redactions.event_id, redacts FROM redactions
+ LEFT JOIN events AS original_event ON (
+ redacts = original_event.event_id
+ )
+ WHERE NOT have_censored
+ AND redactions.received_ts <= ?
+ ORDER BY redactions.received_ts ASC
+ LIMIT ?
+ """
+
+ rows = yield self._execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
+
+ updates = []
+
+ for redaction_id, event_id in rows:
+ redaction_event = yield self.get_event(redaction_id, allow_none=True)
+ original_event = yield self.get_event(
+ event_id, allow_rejected=True, allow_none=True
+ )
+
+ # The SQL above ensures that we have both the redaction and
+ # original event, so if the `get_event` calls return None it
+ # means that the redaction wasn't allowed. Either way we know that
+ # the result won't change so we mark the fact that we've checked.
+ if (
+ redaction_event
+ and original_event
+ and original_event.internal_metadata.is_redacted()
+ ):
+ # Redaction was allowed
+ pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
+ else:
+ # Redaction wasn't allowed
+ pruned_json = None
+
+ updates.append((redaction_id, event_id, pruned_json))
+
+ def _update_censor_txn(txn):
+ for redaction_id, event_id, pruned_json in updates:
+ if pruned_json:
+ self._simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
+ )
+
+ self._simple_update_one_txn(
+ txn,
+ table="redactions",
+ keyvalues={"event_id": redaction_id},
+ updatevalues={"have_censored": True},
+ )
+
+ yield self.runInteraction("_update_censor_txn", _update_censor_txn)
+
+ @defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -2139,11 +2258,12 @@ class EventsStore(
sql = """
SELECT DISTINCT state_group FROM event_to_state_groups
LEFT JOIN events_to_purge AS ep USING (event_id)
- WHERE state_group IN (%s) AND ep.event_id IS NULL
- """ % (
- ",".join("?" for _ in current_search),
+ WHERE ep.event_id IS NULL AND
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "state_group", current_search
)
- txn.execute(sql, list(current_search))
+ txn.execute(sql + clause, list(args))
referenced = set(sg for sg, in txn)
referenced_groups |= referenced
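
The branch in `_censor_redactions` is worth spelling out, since both arms
mark the row as handled. A hedged restatement as a hypothetical helper
(mirroring, not replacing, the code above):

    def pruned_json_for(redaction_event, original_event):
        if (
            redaction_event
            and original_event
            and original_event.internal_metadata.is_redacted()
        ):
            # The redaction took effect: overwrite the stored JSON with
            # the pruned form of the original event.
            return encode_json(prune_event_dict(original_event.get_dict()))
        # The redaction was not allowed; there is nothing to overwrite,
        # but the row is still flagged have_censored so it is never
        # re-examined.
        return None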
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
index 6587f31e2b..31ea6f917f 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/events_bg_updates.py
@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
@@ -67,6 +68,23 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
)
+ self.register_background_update_handler(
+ "redactions_received_ts", self._redactions_received_ts
+ )
+
+ # This index gets deleted in the `event_fix_redactions_bytes` update
+ self.register_background_index_update(
+ "event_fix_redactions_bytes_create_index",
+ index_name="redactions_censored_redacts",
+ table="redactions",
+ columns=["redacts"],
+ where_clause="have_censored",
+ )
+
+ self.register_background_update_handler(
+ "event_fix_redactions_bytes", self._event_fix_redactions_bytes
+ )
+
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -308,12 +326,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
WHERE
- prev_event_id IN (%s)
- AND NOT events.outlier
- """ % (
- ",".join("?" for _ in to_check),
+ NOT events.outlier
+ AND
+ """
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "prev_event_id", to_check
)
- txn.execute(sql, to_check)
+ txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph:
@@ -397,3 +416,90 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
)
return num_handled
+
+ @defer.inlineCallbacks
+ def _redactions_received_ts(self, progress, batch_size):
+ """Handles filling out the `received_ts` column in redactions.
+ """
+ last_event_id = progress.get("last_event_id", "")
+
+ def _redactions_received_ts_txn(txn):
+ # Fetch the set of event IDs that we want to update
+ sql = """
+ SELECT event_id FROM redactions
+ WHERE event_id > ?
+ ORDER BY event_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ upper_event_id, = rows[-1]
+
+ # Update the redactions with the received_ts.
+ #
+ # Note: Not all events have an associated received_ts, so we
+ # fall back to using origin_server_ts. If for some reason we don't
+ # have an origin_server_ts, let's just use the current timestamp.
+ #
+ # We don't want to leave it null, as then we'll never try and
+ # censor those redactions.
+ sql = """
+ UPDATE redactions
+ SET received_ts = (
+ SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
+ WHERE events.event_id = redactions.event_id
+ )
+ WHERE ? <= event_id AND event_id <= ?
+ """
+
+ txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
+
+ self._background_update_progress_txn(
+ txn, "redactions_received_ts", {"last_event_id": upper_event_id}
+ )
+
+ return len(rows)
+
+ count = yield self.runInteraction(
+ "_redactions_received_ts", _redactions_received_ts_txn
+ )
+
+ if not count:
+ yield self._end_background_update("redactions_received_ts")
+
+ return count
+
+ @defer.inlineCallbacks
+ def _event_fix_redactions_bytes(self, progress, batch_size):
+ """Undoes hex encoded censored redacted event JSON.
+ """
+
+ def _event_fix_redactions_bytes_txn(txn):
+ # This update is quite fast thanks to the new index.
+ txn.execute(
+ """
+ UPDATE event_json
+ SET
+ json = convert_from(json::bytea, 'utf8')
+ FROM redactions
+ WHERE
+ redactions.have_censored
+ AND event_json.event_id = redactions.redacts
+ AND json NOT LIKE '{%';
+ """
+ )
+
+ txn.execute("DROP INDEX redactions_censored_redacts")
+
+ yield self.runInteraction(
+ "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
+ )
+
+ yield self._end_background_update("event_fix_redactions_bytes")
+
+ return 1
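
The COALESCE in `_redactions_received_ts` performs a three-way fallback; in
Python terms (hypothetical helper, with None standing in for SQL NULL):

    def backfill_received_ts(received_ts, origin_server_ts, now_ms):
        # Prefer the locally recorded receive time, then the origin
        # server's timestamp, then "now" -- never NULL, or the censor
        # job would skip the row forever.
        for candidate in (received_ts, origin_server_ts, now_ms):
            if candidate is not None:
                return candidate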
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index c6fa7f82fd..4c4b76bd93 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -31,12 +31,11 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
from synapse.util.metrics import Measure
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
@@ -238,6 +237,20 @@ class EventsWorkerStore(SQLBaseStore):
# we have to recheck auth now.
if not allow_rejected and entry.event.type == EventTypes.Redaction:
+ if not hasattr(entry.event, "redacts"):
+ # A redacted redaction doesn't have a `redacts` key, in
+ # which case let's just withhold the event.
+ #
+ # Note: Most of the time if the redaction has been
+ # redacted we still have the un-redacted event in the DB
+ # and so we'll still see the `redacts` key. However, this
+ # isn't always true, e.g. if we have censored the event.
+ logger.debug(
+ "Withholding redaction event %s as we don't have redacts key",
+ event_id,
+ )
+ continue
+
redacted_event_id = entry.event.redacts
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
@@ -609,10 +622,14 @@ class EventsWorkerStore(SQLBaseStore):
" rej.reason "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
+ " WHERE "
+ )
- txn.execute(sql, evs)
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "e.event_id", evs
+ )
+
+ txn.execute(sql + clause, args)
for row in txn:
event_id = row[0]
@@ -626,11 +643,11 @@ class EventsWorkerStore(SQLBaseStore):
}
# check for redactions
- redactions_sql = (
- "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
+ redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
+
+ clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
- txn.execute(redactions_sql, evs)
+ txn.execute(redactions_sql + clause, args)
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
@@ -739,10 +756,11 @@ class EventsWorkerStore(SQLBaseStore):
results = set()
def have_seen_events_txn(txn, chunk):
- sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
- ",".join("?" * len(chunk)),
+ sql = "SELECT event_id FROM events as e WHERE "
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "e.event_id", chunk
)
- txn.execute(sql, chunk)
+ txn.execute(sql + clause, args)
for (event_id,) in txn:
results.add(event_id)
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 23b48f6cea..7c2a7da836 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -51,7 +51,7 @@ class FilteringStore(SQLBaseStore):
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
)
- txn.execute(sql, (user_localpart, def_json))
+ txn.execute(sql, (user_localpart, bytearray(def_json)))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
@@ -68,7 +68,7 @@ class FilteringStore(SQLBaseStore):
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
- txn.execute(sql, (user_localpart, filter_id, def_json))
+ txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 6b1238ce4a..84b5f3ad5e 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -15,11 +15,9 @@
from synapse.storage.background_updates import BackgroundUpdateStore
-class MediaRepositoryStore(BackgroundUpdateStore):
- """Persistence for attachments and avatars"""
-
+class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(db_conn, hs)
+ super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
update_name="local_media_repository_url_idx",
@@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause="url_cache IS NOT NULL",
)
+
+class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
+ """Persistence for attachments and avatars"""
+
+ def __init__(self, db_conn, hs):
+ super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 752e9788a2..3803604be7 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -32,7 +32,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock()
self.hs = hs
- self.reserved_users = ()
# Do not add more reserved users than the total allowable number
self._new_transaction(
dbconn,
@@ -51,7 +50,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
- reserved_user_list = []
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -60,10 +58,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
self.upsert_monthly_active_user_txn(txn, user_id)
- reserved_user_list.append(user_id)
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
- self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
@@ -74,8 +70,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[]
"""
- def _reap_users(txn):
- # Purge stale users
+ def _reap_users(txn, reserved_users):
+ """
+ Args:
+ reserved_users (tuple): reserved users to preserve
+ """
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
@@ -83,20 +82,19 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
- if len(self.reserved_users) > 0:
+ if len(reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
- questionmarks = "?" * len(self.reserved_users)
+ question_marks = ",".join("?" * len(reserved_users))
- query_args.extend(self.reserved_users)
- sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ",".join(questionmarks)
- )
+ query_args.extend(reserved_users)
+ sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
else:
sql = base_sql
txn.execute(sql, query_args)
+ max_mau_value = self.hs.config.max_mau_value
if self.hs.config.limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
@@ -106,31 +104,52 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
- safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
- # Must be greater than zero for postgres
- safe_guard = safe_guard if safe_guard > 0 else 0
- query_args = [safe_guard]
-
- base_sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- ORDER BY timestamp DESC
- LIMIT ?
+ if len(reserved_users) == 0:
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ ORDER BY timestamp DESC
+ LIMIT ?
)
- """
+ """
+ txn.execute(sql, (max_mau_value,))
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
- if len(self.reserved_users) > 0:
- query_args.extend(self.reserved_users)
- sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ",".join(questionmarks)
- )
else:
- sql = base_sql
- txn.execute(sql, query_args)
+ # Number of non-reserved users to keep; must be >= 0 for postgres
+ num_of_non_reserved_users_to_keep = max(
+ max_mau_value - len(reserved_users), 0
+ )
+
+ # It is important to filter reserved users twice to guard
+ # against the case where the reserved user is present in the
+ # SELECT, meaning that a legitimate MAU is deleted.
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ WHERE user_id NOT IN ({})
+ ORDER BY timestamp DESC
+ LIMIT ?
+ )
+ AND user_id NOT IN ({})
+ """.format(
+ question_marks, question_marks
+ )
+
+ query_args = [
+ *reserved_users,
+ num_of_non_reserved_users_to_keep,
+ *reserved_users,
+ ]
- yield self.runInteraction("reap_monthly_active_users", _reap_users)
+ txn.execute(sql, query_args)
+
+ reserved_users = yield self.get_registered_reserved_users()
+ yield self.runInteraction(
+ "reap_monthly_active_users", _reap_users, reserved_users
+ )
# It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
@@ -159,21 +178,25 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
- def get_registered_reserved_users_count(self):
- """Of the reserved threepids defined in config, how many are associated
+ def get_registered_reserved_users(self):
+ """Of the reserved threepids defined in config, which are associated
with registered users?
Returns:
- Defered[int]: Number of real reserved users
+ Deferred[list]: Real reserved users
"""
- count = 0
- for tp in self.hs.config.mau_limits_reserved_threepids:
+ users = []
+
+ for tp in self.hs.config.mau_limits_reserved_threepids[
+ : self.hs.config.max_mau_value
+ ]:
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
- count = count + 1
- return count
+ users.append(user_id)
+
+ return users
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
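
A worked example of why `_reap_users` filters reserved users both inside
the SELECT and outside the DELETE (illustrative values): suppose
max_mau_value = 3, reserved_users = ("@r:x",), and the table newest-first
holds @r:x, @a:x, @b:x, @c:x, so 3 - 1 = 2 non-reserved users may be kept:

    # Inner filter: the kept set is the 2 newest *non-reserved* users.
    kept = ["@a:x", "@b:x"]
    # Outer filter: @r:x is not in `kept`, so without the second
    # NOT IN it would be deleted despite being reserved.
    deleted = ["@c:x"]
    # Dropping only the inner filter instead would keep ("@r:x", "@a:x")
    # and wrongly delete the legitimate MAU @b:x.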
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 5db6f2d84a..3a641f538b 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -18,11 +18,10 @@ from collections import namedtuple
from twisted.internet import defer
from synapse.api.constants import PresenceState
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedList
-from ._base import SQLBaseStore
-
class UserPresenceState(
namedtuple(
@@ -119,14 +118,13 @@ class PresenceStore(SQLBaseStore):
)
# Delete old rows to stop database from getting really big
- sql = (
- "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
- )
+ sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
for states in batch_iter(presence_states, 50):
- args = [stream_id]
- args.extend(s.user_id for s in states)
- txn.execute(sql % (",".join("?" for _ in states),), args)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", [s.user_id for s in states]
+ )
+ txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id:
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index a6517c4cf3..c4e24edff2 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -183,8 +183,8 @@ class PushRulesWorkerStore(
return results
@defer.inlineCallbacks
- def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
- """Move a single push rule from one room to another for a specific user.
+ def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ """Copy a single push rule from one room to another for a specific user.
Args:
new_room_id (str): ID of the new room.
@@ -209,14 +209,11 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
- # Delete push rule for the old room
- yield self.delete_push_rule(user_id, rule["rule_id"])
-
@defer.inlineCallbacks
- def move_push_rules_from_room_to_room_for_user(
+ def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id
):
- """Move all of the push rules from one room to another for a specific
+ """Copy all of the push rules from one room to another for a specific
user.
Args:
@@ -227,15 +224,14 @@ class PushRulesWorkerStore(
# Retrieve push rules for this user
user_push_rules = yield self.get_push_rules_for_user(user_id)
- # Get rules relating to the old room, move them to the new room, then
- # delete them from the old room
+ # Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
- self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
+ yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 3e0e834a62..b12e80440a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -241,7 +241,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
- "data": encode_canonical_json(data),
+ "data": bytearray(encode_canonical_json(data)),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 290ddb30e8..0c24430f28 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -21,12 +21,11 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import SQLBaseStore
-from .util.id_generators import StreamIdGenerator
-
logger = logging.getLogger(__name__)
@@ -217,24 +216,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn):
if from_key:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
- ) % (",".join(["?"] * len(room_ids)))
- args = list(room_ids)
- args.extend([from_key, to_key])
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id > ? AND stream_id <= ? AND
+ """
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
- txn.execute(sql, args)
+ txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id IN (%s) AND stream_id <= ?"
- ) % (",".join(["?"] * len(room_ids)))
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id <= ? AND
+ """
- args = list(room_ids)
- args.append(to_key)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
- txn.execute(sql, args)
+ txn.execute(sql + clause, [to_key] + list(args))
return self.cursor_to_dict(txn)
@@ -433,13 +434,19 @@ class ReceiptsStore(ReceiptsWorkerStore):
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
- query = (
- "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
- " SELECT max(stream_ordering) WHERE event_id IN (%s)"
- ")"
- ) % (",".join(["?"] * len(event_ids)))
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", event_ids
+ )
+
+ sql = """
+ SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+ SELECT max(stream_ordering) WHERE %s
+ )
+ """ % (
+ clause,
+ )
- txn.execute(query, [room_id] + event_ids)
+ txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 5138792a5f..6c5b29288a 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,9 +22,10 @@ from six import iterkeys
from six.moves import range
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, ThreepidValidationError
+from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
@@ -323,6 +324,19 @@ class RegistrationWorkerStore(SQLBaseStore):
return None
@cachedInlineCallbacks()
+ def is_real_user(self, user_id):
+ """Determines if the user is a real user, ie does not have a 'user_type'.
+
+ Args:
+ user_id (str): user id to test
+
+ Returns:
+ Deferred[bool]: True if user 'user_type' is null or empty string
+ """
+ res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+ return res
+
+ @cachedInlineCallbacks()
def is_support_user(self, user_id):
"""Determines if the user is of type UserTypes.SUPPORT
@@ -337,6 +351,16 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return res
+ def is_real_user_txn(self, txn, user_id):
+ res = self._simple_select_one_onecol_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="user_type",
+ allow_none=True,
+ )
+ return res is None
+
def is_support_user_txn(self, txn, user_id):
res = self._simple_select_one_onecol_txn(
txn=txn,
@@ -361,6 +385,26 @@ class RegistrationWorkerStore(SQLBaseStore):
return self.runInteraction("get_users_by_id_case_insensitive", f)
+ async def get_user_by_external_id(
+ self, auth_provider: str, external_id: str
+ ) -> str:
+ """Look up a user by their external auth id
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+
+ Returns:
+ str|None: the mxid of the user, or None if they are not known
+ """
+ return await self._simple_select_one_onecol(
+ table="user_external_ids",
+ keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_user_by_external_id",
+ )
+
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -422,6 +466,20 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret
@defer.inlineCallbacks
+ def count_real_users(self):
+ """Counts all users without a special user_type registered on the homeserver."""
+
+ def _count_users(txn):
+ txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
+ rows = self.cursor_to_dict(txn)
+ if rows:
+ return rows[0]["users"]
+ return 0
+
+ ret = yield self.runInteraction("count_real_users", _count_users)
+ return ret
+
+ @defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
@@ -435,7 +493,9 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def _find_next_generated_user_id(txn):
- txn.execute("SELECT name FROM users")
+ # We bound between '@1' and '@a' to avoid pulling the entire table
+ # out.
+ txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
@@ -458,7 +518,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address, require_verified=False):
+ def get_user_id_by_threepid(self, medium, address):
"""Returns user id from threepid
Args:
@@ -549,6 +609,26 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
+ def user_get_bound_threepids(self, user_id):
+ """Get the threepids that a user has bound to an identity server through the homeserver
+ The homeserver remembers where binds to an identity server occurred. Using this
+ method can retrieve those threepids.
+
+ Args:
+ user_id (str): The ID of the user to retrieve threepids for
+
+ Returns:
+ Deferred[list[dict]]: List of dictionaries containing the following:
+ medium (str): The medium of the threepid (e.g "email")
+ address (str): The address of the threepid (e.g "bob@example.com")
+ """
+ return self._simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ )
+
def remove_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied an unbind request to the given identity server on
behalf of the given user, so we remove the mapping of threepid to
@@ -618,24 +698,37 @@ class RegistrationWorkerStore(SQLBaseStore):
self, medium, client_secret, address=None, sid=None, validated=True
):
"""Gets a session_id and last_send_attempt (if available) for a
- client_secret/medium/(address|session_id) combo
+ combination of validation metadata
Args:
medium (str|None): The medium of the 3PID
address (str|None): The address of the 3PID
sid (str|None): The ID of the validation session
- client_secret (str|None): A unique string provided by the client to
- help identify this validation attempt
+ client_secret (str): A unique string provided by the client to help identify this
+ validation attempt
validated (bool|None): Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
- deferred {str, int}|None: A dict containing the
- latest session_id and send_attempt count for this 3PID.
- Otherwise None if there hasn't been a previous attempt
+ Deferred[dict|None]: A dict containing the following:
+ * address - address of the 3pid
+ * medium - medium of the 3pid
+ * client_secret - a secret provided by the client for this validation session
+ * session_id - ID of the validation session
+ * send_attempt - a number serving to dedupe send attempts for this session
+ * validated_at - timestamp of when this session was validated if so
+
+ Otherwise None if a validation session is not found
"""
- keyvalues = {"medium": medium, "client_secret": client_secret}
+ if not client_secret:
+ raise SynapseError(
+ 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
+ )
+
+ keyvalues = {"client_secret": client_secret}
+ if medium:
+ keyvalues["medium"] = medium
if address:
keyvalues["address"] = address
if sid:
@@ -694,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
)
-class RegistrationStore(
+class RegistrationBackgroundUpdateStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
):
def __init__(self, db_conn, hs):
- super(RegistrationStore, self).__init__(db_conn, hs)
+ super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
+ self.config = hs.config
self.register_background_index_update(
"access_tokens_device_index",
@@ -716,8 +810,6 @@ class RegistrationStore(
columns=["creation_ts"],
)
- self._account_validity = hs.config.account_validity
-
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@@ -731,17 +823,6 @@ class RegistrationStore(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
-
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
@defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@@ -774,7 +855,7 @@ class RegistrationStore(
rows = self.cursor_to_dict(txn)
if not rows:
- return True
+ return True, 0
rows_processed_nb = 0
@@ -790,18 +871,66 @@ class RegistrationStore(
)
if batch_size > len(rows):
- return True
+ return True, len(rows)
else:
- return False
+ return False, len(rows)
- end = yield self.runInteraction(
+ end, nb_processed = yield self.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
yield self._end_background_update("users_set_deactivated_flag")
- return batch_size
+ return nb_processed
+
+ @defer.inlineCallbacks
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server-configured trusted identity servers.
+ """
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
+ )
+
+ yield self._end_background_update("user_threepids_grandfather")
+
+ return 1
+
+
+class RegistrationStore(RegistrationBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(RegistrationStore, self).__init__(db_conn, hs)
+
+ self._account_validity = hs.config.account_validity
+
+ # Create a background job for culling expired 3PID validity tokens
+ def start_cull():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "cull_expired_threepid_validation_tokens",
+ self.cull_expired_threepid_validation_tokens,
+ )
+
+ hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -962,6 +1091,26 @@ class RegistrationStore(
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ def record_user_external_id(
+ self, auth_provider: str, external_id: str, user_id: str
+ ) -> Deferred:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ return self._simple_insert(
+ table="user_external_ids",
+ values={
+ "auth_provider": auth_provider,
+ "external_id": external_id,
+ "user_id": user_id,
+ },
+ desc="record_user_external_id",
+ )
+
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
@@ -1131,36 +1280,6 @@ class RegistrationStore(
desc="get_users_pending_deactivation",
)
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- yield self.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
- )
-
- yield self._end_background_update("user_threepids_grandfather")
-
- return 1
-
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
@@ -1172,6 +1291,10 @@ class RegistrationStore(
current_ts (int): The current unix time in milliseconds. Used for
checking token expiry status
+ Raises:
+ ThreepidValidationError: if a matching validation token was not found or has
+ expired
+
Returns:
deferred str|None: A str representing a link to redirect the user
to if there is one.
@@ -1348,17 +1471,6 @@ class RegistrationStore(
self.clock.time_msec(),
)
- def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self._simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"deactivated": 1 if deactivated else 0},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,)
- )
-
@defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated):
"""Set the `deactivated` property for the provided user to the provided value.
@@ -1374,3 +1486,14 @@ class RegistrationStore(
user_id,
deactivated,
)
+
+ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ self._simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
+ )
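
The '@1'/'@a' bound in `find_next_generated_user_id_localpart` relies on
ASCII ordering: digit-prefixed localparts sort below letter-prefixed ones,
so the range selects numeric IDs without scanning names like @alice, and
the query can be answered from the index on users(name). Illustration:

    assert "@1" <= "@1234:example.com" < "@a"         # numeric id: in range
    assert not ("@1" <= "@alice:example.com" < "@a")  # named user: skipped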
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 08e13f3a3b..43cc56fa6f 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import collections
import logging
import re
+from typing import Optional, Tuple
from canonicaljson import json
@@ -24,6 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
+from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -63,103 +66,196 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
+ def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ """Counts the number of public rooms as tracked in the room_stats_current
+ and room_stats_state tables.
Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lsits.
+ network_tuple (ThirdPartyInstanceID|None)
+ ignore_non_federatable (bool): If true, filter out non-federatable rooms
"""
- return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id,
- network_tuple=network_tuple,
- )
-
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
+ def _count_public_rooms_txn(txn):
+ query_args = []
+
+ if network_tuple:
+ if network_tuple.appservice_id:
+ published_sql = """
+ SELECT room_id from appservice_room_list
+ WHERE appservice_id = ? AND network_id = ?
+ """
+ query_args.append(network_tuple.appservice_id)
+ query_args.append(network_tuple.network_id)
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ """
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ UNION SELECT room_id from appservice_room_list
+ """
sql = """
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
+ SELECT
+ COALESCE(COUNT(*), 0)
+ FROM (
+ %(published_sql)s
+ ) published
+ INNER JOIN room_stats_state USING (room_id)
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ (
+ join_rules = 'public' OR history_visibility = 'world_readable'
+ )
+ AND joined_members > 0
+ """ % {
+ "published_sql": published_sql
+ }
+
+ txn.execute(sql, query_args)
+ return txn.fetchone()[0]
+
+ return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
+
+ @defer.inlineCallbacks
+ def get_largest_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ search_filter: Optional[dict],
+ limit: Optional[int],
+ bounds: Optional[Tuple[int, str]],
+ forwards: bool,
+ ignore_non_federatable: bool = False,
+ ):
+ """Gets the largest public rooms (where largest is in terms of joined
+ members, as tracked in the statistics table).
+
+ Args:
+ network_tuple
+ search_filter
+ limit: Maximum number of rows to return; unlimited if None.
+ bounds: An upper or lower bound to apply to the result set if given,
+ consisting of a joined member count and room_id (both are
+ excluded from the result set).
+ forwards: True if paginating forwards, False if backwards.
+ ignore_non_federatable: If true, filter out non-federatable rooms.
+
+ Returns:
+ Rooms in order: biggest number of joined users first.
+ We then arbitrarily use the room_id as a tie breaker.
+
+ """
+
+ where_clauses = []
+ query_args = []
+
+ if network_tuple:
+ if network_tuple.appservice_id:
+ published_sql = """
+ SELECT room_id from appservice_room_list
+ WHERE appservice_id = ? AND network_id = ?
+ """
+ query_args.append(network_tuple.appservice_id)
+ query_args.append(network_tuple.network_id)
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ """
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ UNION SELECT room_id from appservice_room_list
"""
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id),
+ # Work out the bounds if we're given them; these bounds look slightly
+ # odd, but are designed to help the query planner use indices by
+ # pulling out a common bound.
+ if bounds:
+ last_joined_members, last_room_id = bounds
+ if forwards:
+ where_clauses.append(
+ """
+ joined_members <= ? AND (
+ joined_members < ? OR room_id < ?
+ )
+ """
)
else:
- txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
- return dict(txn)
- else:
- # We want to get from all lists, so we need to aggregate the results
+ where_clauses.append(
+ """
+ joined_members >= ? AND (
+ joined_members > ? OR room_id > ?
+ )
+ """
+ )
- logger.info("Executing full list")
+ query_args += [last_joined_members, last_joined_members, last_room_id]
- sql = """
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """
+ if ignore_non_federatable:
+ where_clauses.append("is_federatable")
- txn.execute(sql, (stream_id,))
+ if search_filter and search_filter.get("generic_search_term", None):
+ search_term = "%" + search_filter["generic_search_term"] + "%"
- results = {}
- # A room is visible if its visible on any list.
- for room_id, visibility in txn:
- results[room_id] = bool(visibility) or results.get(room_id, False)
+ where_clauses.append(
+ """
+ (
+ name LIKE ?
+ OR topic LIKE ?
+ OR canonical_alias LIKE ?
+ )
+ """
+ )
+ query_args += [search_term, search_term, search_term]
+
+ where_clause = ""
+ if where_clauses:
+ where_clause = " AND " + " AND ".join(where_clauses)
+
+ sql = """
+ SELECT
+                room_id, name, topic, canonical_alias, joined_members,
+                avatar, history_visibility, guest_access
+ FROM (
+ %(published_sql)s
+ ) published
+ INNER JOIN room_stats_state USING (room_id)
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ (
+ join_rules = 'public' OR history_visibility = 'world_readable'
+ )
+ AND joined_members > 0
+ %(where_clause)s
+ ORDER BY joined_members %(dir)s, room_id %(dir)s
+ """ % {
+ "published_sql": published_sql,
+ "where_clause": where_clause,
+ "dir": "DESC" if forwards else "ASC",
+ }
- return results
+ if limit is not None:
+ query_args.append(limit)
- def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
+ sql += """
+ LIMIT ?
+ """
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
+ def _get_largest_public_rooms_txn(txn):
+ txn.execute(sql, query_args)
- now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
- )
+ results = self.cursor_to_dict(txn)
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
+ if not forwards:
+ results.reverse()
- return newly_visible, newly_unpublished
+ return results
- return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
+ ret_val = yield self.runInteraction(
+ "get_largest_public_rooms", _get_largest_public_rooms_txn
)
+ defer.returnValue(ret_val)
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
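
The bounds handling in `get_largest_public_rooms` is keyset (seek) pagination over `(joined_members, room_id)`: instead of an OFFSET, each page is anchored at the last row of the previous one, and the duplicated outer bound (`joined_members <= ?`) gives the query planner a plain range condition to match against an index. A minimal sketch of how a caller might derive the next `bounds` tuple when paging forwards (an illustrative helper, not part of Synapse's API):

```python
def next_page_bounds(page):
    """Derive the `bounds` argument for the next forwards call to
    get_largest_public_rooms from the current page of results.

    Returns None when the page is empty, i.e. we have reached the end.
    """
    if not page:
        return None
    # Forwards pages are ordered biggest-first, so the last row is the
    # smallest seen so far and becomes the exclusive anchor for the next page.
    last = page[-1]
    return (last["joined_members"], last["room_id"])


page = [
    {"room_id": "!a:example.org", "joined_members": 120},
    {"room_id": "!b:example.org", "joined_members": 98},
]
print(next_page_bounds(page))  # (98, '!b:example.org')
```
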
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4df8ebdacd..ff63487823 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -26,13 +26,15 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction
+from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -370,6 +372,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = []
if membership_list:
if self._current_state_events_membership_up_to_date:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "c.membership", membership_list
+ )
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c
@@ -377,11 +382,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE
c.type = 'm.room.member'
AND state_key = ?
- AND c.membership IN (%s)
+ AND %s
""" % (
- ",".join("?" * len(membership_list))
+ clause,
)
else:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "m.membership", membership_list
+ )
sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c
@@ -390,12 +398,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE
c.type = 'm.room.member'
AND state_key = ?
- AND m.membership IN (%s)
+ AND %s
""" % (
- ",".join("?" * len(membership_list))
+ clause,
)
- txn.execute(sql, (user_id, *membership_list))
+ txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite:
@@ -483,6 +491,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return result
+ @defer.inlineCallbacks
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
@@ -492,9 +501,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
+ with Measure(self._clock, "get_joined_users_from_state"):
+ return (
+ yield self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry
+ )
+ )
@cachedInlineCallbacks(
num_args=2, cache_context=True, iterable=True, max_entries=100000
@@ -567,25 +579,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
- rows = yield self._simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=missing_member_event_ids,
- retcols=("user_id", "display_name", "avatar_url"),
- keyvalues={"membership": Membership.JOIN},
- batch_size=500,
- desc="_get_joined_users_from_context",
- )
-
- users_in_room.update(
- {
- to_ascii(row["user_id"]): ProfileInfo(
- avatar_url=to_ascii(row["avatar_url"]),
- display_name=to_ascii(row["display_name"]),
- )
- for row in rows
- }
+ event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+ missing_member_event_ids
)
+ users_in_room.update((row for row in event_to_memberships.values() if row))
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -597,6 +594,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
+ @cached(max_entries=10000)
+ def _get_joined_profile_from_event_id(self, event_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_joined_profile_from_event_id",
+ list_name="event_ids",
+ inlineCallbacks=True,
+ )
+ def _get_joined_profiles_from_event_ids(self, event_ids):
+ """For given set of member event_ids check if they point to a join
+ event and if so return the associated user and profile info.
+
+ Args:
+ event_ids (Iterable[str]): The member event IDs to lookup
+
+ Returns:
+ Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+                to `user_id` and ProfileInfo (or None if not a join event).
+ """
+
+ rows = yield self._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=500,
+ desc="_get_membership_from_event_ids",
+ )
+
+ return {
+ row["event_id"]: (
+ row["user_id"],
+ ProfileInfo(
+ avatar_url=row["avatar_url"], display_name=row["display_name"]
+ ),
+ )
+ for row in rows
+ }
+
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
if "%" in host or "_" in host:
@@ -669,6 +707,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
+ @defer.inlineCallbacks
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
@@ -678,9 +717,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
- )
+ with Measure(self._clock, "get_joined_hosts"):
+ return (
+ yield self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry
+ )
+ )
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
# @defer.inlineCallbacks
@@ -785,9 +827,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
-class RoomMemberStore(RoomMemberWorkerStore):
+class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
- super(RoomMemberStore, self).__init__(db_conn, hs)
+ super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@@ -803,112 +845,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
where_clause="forgotten = 1",
)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ],
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key,
- event.internal_metadata.stream_ordering,
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
- def forget(self, user_id, room_id):
- """Indicate that user_id wishes to discard history for room_id."""
-
- def f(txn):
- sql = (
- "UPDATE"
- " room_memberships"
- " SET"
- " forgotten = 1"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- )
- txn.execute(sql, (user_id, room_id))
-
- self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
- self._invalidate_cache_and_stream(
- txn, self.get_forgotten_rooms_for_user, (user_id,)
- )
-
- return self.runInteraction("forget_membership", f)
-
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@@ -1043,6 +979,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
return row_count
+class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberStore, self).__init__(db_conn, hs)
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ],
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key,
+ event.internal_metadata.stream_ordering,
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
+
+        # We update the local_invites table only if the event is "current",
+        # i.e., it's something that has just happened. If the event is an
+        # outlier, it is only current if it's an "out of band membership",
+        # like a remote invite or a rejection of a remote invite.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_out_of_band_membership()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ },
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(
+ sql,
+ (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ),
+ )
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (stream_ordering, True, room_id, user_id))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+
+ self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_forgotten_rooms_for_user, (user_id,)
+ )
+
+ return self.runInteraction("forget_membership", f)
+
+
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.
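
The new `_get_joined_profile_from_event_id`/`_get_joined_profiles_from_event_ids` pair above is the `@cached`/`@cachedList` pattern: the single-item method exists only to own a cache (hence its `NotImplementedError` body), while the list variant batch-fetches the misses and back-fills that cache. A rough illustration of the idea with plain dictionaries (hand-rolled; not Synapse's actual descriptor machinery):

```python
# Illustrative only: a hand-rolled version of the cached/cachedList pattern.
_profile_cache = {}  # event_id -> (user_id, profile) or None


def get_joined_profiles(event_ids, fetch_batch):
    """Return {event_id: value}, consulting the cache first and batch-fetching
    only the misses via `fetch_batch` (assumed to take a list of event IDs and
    return a dict covering the join events among them)."""
    results = {}
    misses = []
    for event_id in event_ids:
        if event_id in _profile_cache:
            results[event_id] = _profile_cache[event_id]
        else:
            misses.append(event_id)
    if misses:
        fetched = fetch_batch(misses)
        for event_id in misses:
            # Cache negative results too, so non-join events aren't re-queried.
            value = fetched.get(event_id)
            _profile_cache[event_id] = value
            results[event_id] = value
    return results
```
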
diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/schema/delta/56/destinations_failure_ts.sql
new file mode 100644
index 0000000000..f00889290b
--- /dev/null
+++ b/synapse/storage/schema/delta/56/destinations_failure_ts.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Record the timestamp when a given server started failing
+ */
+ALTER TABLE destinations ADD failure_ts BIGINT;
+
+/* As a rough approximation, we assume that the server started failing
+ * retry_interval ms before the last retry.
+ */
+UPDATE destinations SET failure_ts = retry_last_ts - retry_interval
+ WHERE retry_last_ts > 0;
diff --git a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres
new file mode 100644
index 0000000000..b9bbb18a91
--- /dev/null
+++ b/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We want to store large retry intervals so we upgrade the column from INT
+-- to BIGINT. We don't need to do this on SQLite.
+ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT;
diff --git a/synapse/storage/schema/delta/56/devices_last_seen.sql b/synapse/storage/schema/delta/56/devices_last_seen.sql
new file mode 100644
index 0000000000..dfa902d0ba
--- /dev/null
+++ b/synapse/storage/schema/delta/56/devices_last_seen.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track last seen information for a device in the devices table, rather
+-- than relying on it being in the user_ips table (which we want to be able
+-- to purge old entries from)
+ALTER TABLE devices ADD COLUMN last_seen BIGINT;
+ALTER TABLE devices ADD COLUMN ip TEXT;
+ALTER TABLE devices ADD COLUMN user_agent TEXT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('devices_last_seen', '{}');
diff --git a/synapse/storage/schema/delta/56/public_room_list_idx.sql b/synapse/storage/schema/delta/56/public_room_list_idx.sql
new file mode 100644
index 0000000000..7be31ffebb
--- /dev/null
+++ b/synapse/storage/schema/delta/56/public_room_list_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);
diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/schema/delta/56/redaction_censor.sql
new file mode 100644
index 0000000000..fe51b02309
--- /dev/null
+++ b/synapse/storage/schema/delta/56/redaction_censor.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
+CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
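
The `WHERE not have_censored` clause makes `redactions_have_censored` a partial index: it only contains the rows the censoring background job still has to visit, so it stays small as redactions are processed. A runnable sqlite3 sketch of the same shape (table columns simplified; SQLite also supports partial indexes):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE redactions (
        event_id TEXT NOT NULL,
        redacts TEXT NOT NULL,
        have_censored BOOL NOT NULL DEFAULT 0
    );
    CREATE INDEX redactions_have_censored
        ON redactions(event_id) WHERE NOT have_censored;
    """
)
conn.executemany(
    "INSERT INTO redactions VALUES (?, ?, ?)",
    [("$r1", "$e1", 0), ("$r2", "$e2", 1), ("$r3", "$e3", 0)],
)
# Only the uncensored rows need scanning; the partial index covers exactly those.
rows = conn.execute(
    "SELECT event_id FROM redactions WHERE NOT have_censored"
).fetchall()
print(rows)  # [('$r1',), ('$r3',)]
```
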
diff --git a/synapse/storage/schema/delta/56/redaction_censor2.sql b/synapse/storage/schema/delta/56/redaction_censor2.sql
new file mode 100644
index 0000000000..77a5eca499
--- /dev/null
+++ b/synapse/storage/schema/delta/56/redaction_censor2.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
+CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('redactions_received_ts', '{}');
diff --git a/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres
new file mode 100644
index 0000000000..67471f3ef5
--- /dev/null
+++ b/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- There was a bug where we may have updated censored redactions as bytes,
+-- which can (somehow) cause JSON to be inserted hex-encoded. These updates
+-- find and undo any such hex-encoded JSON.
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_fix_redactions_bytes_create_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+ VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index');
diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/schema/delta/56/unique_user_filter_index.py
new file mode 100644
index 0000000000..1de8b54961
--- /dev/null
+++ b/synapse/storage/schema/delta/56/unique_user_filter_index.py
@@ -0,0 +1,52 @@
+import logging
+
+from synapse.storage.engines import PostgresEngine
+
+logger = logging.getLogger(__name__)
+
+
+"""
+This migration updates the user_filters table as follows:
+
+ - drops any (user_id, filter_id) duplicates
+ - makes the columns NON-NULLable
+ - turns the index into a UNIQUE index
+"""
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ select_clause = """
+ SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
+ FROM user_filters
+ """
+ else:
+ select_clause = """
+ SELECT * FROM user_filters GROUP BY user_id, filter_id
+ """
+ sql = """
+ DROP TABLE IF EXISTS user_filters_migration;
+ DROP INDEX IF EXISTS user_filters_unique;
+ CREATE TABLE user_filters_migration (
+ user_id TEXT NOT NULL,
+ filter_id BIGINT NOT NULL,
+ filter_json BYTEA NOT NULL
+ );
+ INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
+ %s;
+ CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
+ (user_id, filter_id);
+ DROP TABLE user_filters;
+ ALTER TABLE user_filters_migration RENAME TO user_filters;
+ """ % (
+ select_clause,
+ )
+
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(sql)
+ else:
+ cur.executescript(sql)
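
The two `select_clause` variants perform the same deduplication in engine-specific ways: Postgres's `DISTINCT ON (user_id, filter_id)` keeps one row per key, while SQLite's bare `GROUP BY user_id, filter_id` picks one arbitrary row per group. A small demonstration of the SQLite behaviour the migration relies on:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE user_filters (user_id TEXT, filter_id BIGINT, filter_json BLOB);
    INSERT INTO user_filters VALUES
        ('@a:x', 0, 'one'), ('@a:x', 0, 'dup'), ('@b:x', 0, 'two');
    """
)
# SQLite's GROUP BY without aggregates picks one arbitrary row per group,
# which is exactly what the migration uses to drop (user_id, filter_id) dupes.
rows = conn.execute(
    "SELECT * FROM user_filters GROUP BY user_id, filter_id"
).fetchall()
print(len(rows))  # 2 - one row per (user_id, filter_id)
```
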
diff --git a/synapse/storage/schema/delta/56/user_external_ids.sql b/synapse/storage/schema/delta/56/user_external_ids.sql
new file mode 100644
index 0000000000..91390c4527
--- /dev/null
+++ b/synapse/storage/schema/delta/56/user_external_ids.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * a table which records mappings from external auth providers to mxids
+ */
+CREATE TABLE IF NOT EXISTS user_external_ids (
+ auth_provider TEXT NOT NULL,
+ external_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ UNIQUE (auth_provider, external_id)
+);
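
A hedged sketch of how a single-sign-on login flow might use this table to resolve an external identity to a Matrix user ID (the provider name and lookup shape here are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS user_external_ids (
        auth_provider TEXT NOT NULL,
        external_id TEXT NOT NULL,
        user_id TEXT NOT NULL,
        UNIQUE (auth_provider, external_id)
    )
    """
)
conn.execute(
    "INSERT INTO user_external_ids VALUES (?, ?, ?)",
    ("saml", "alice@corp.example", "@alice:example.org"),
)
# The UNIQUE constraint guarantees at most one mxid per external identity.
row = conn.execute(
    "SELECT user_id FROM user_external_ids"
    " WHERE auth_provider = ? AND external_id = ?",
    ("saml", "alice@corp.example"),
).fetchone()
print(row[0])  # @alice:example.org
```
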
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index df87ab6a6d..7695bf09fc 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -24,6 +24,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from .background_updates import BackgroundUpdateStore
@@ -36,7 +37,7 @@ SearchEntry = namedtuple(
)
-class SearchStore(BackgroundUpdateStore):
+class SearchBackgroundUpdateStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -44,7 +45,7 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs):
- super(SearchStore, self).__init__(db_conn, hs)
+ super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
if not hs.config.enable_search:
return
@@ -289,29 +290,6 @@ class SearchStore(BackgroundUpdateStore):
return num_rows
- def store_event_search_txn(self, txn, event, key, value):
- """Add event to the search table
-
- Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
- """
- self.store_search_entries_txn(
- txn,
- (
- SearchEntry(
- key=key,
- value=value,
- event_id=event.event_id,
- room_id=event.room_id,
- stream_ordering=event.internal_metadata.stream_ordering,
- origin_server_ts=event.origin_server_ts,
- ),
- ),
- )
-
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
@@ -358,6 +336,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
+
+class SearchStore(SearchBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(SearchStore, self).__init__(db_conn, hs)
+
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),
+ ),
+ )
+
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -380,8 +386,10 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
- args.extend(room_ids)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+ clauses = [clause]
local_clauses = []
for key in keys:
@@ -487,8 +495,10 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
- args.extend(room_ids)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+ clauses = [clause]
local_clauses = []
for key in keys:
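
`make_in_list_sql_clause` (imported above from `synapse.storage._base`) abstracts building a `column IN (...)` test so that it can be rendered per database engine. Going only by its call sites in this diff, a plausible minimal reconstruction, not the exact Synapse implementation, looks like:

```python
def make_in_list_sql_clause(database_engine, column, iterable):
    """Return an SQL clause checking `column` against the given values, plus
    the matching argument list.

    Sketch only: an engine with native array support (e.g. Postgres) could
    instead render `column = ANY(?)` with a single array parameter; this
    portable fallback expands one placeholder per value and ignores
    `database_engine`.
    """
    values = list(iterable)
    placeholders = ", ".join("?" for _ in values)
    return "%s IN (%s)" % (column, placeholders), values


clause, args = make_in_list_sql_clause(None, "room_id", ["!a:x", "!b:x"])
print(clause, args)  # room_id IN (?, ?) ['!a:x', '!b:x']
```
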
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 1980a87108..a941a5ae3f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -353,8 +353,158 @@ class StateFilter(object):
return member_filter, non_member_filter
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+ """Defines functions related to state groups needed to run the state backgroud
+ updates.
+ """
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+ def _get_state_groups_from_groups_txn(
+ self, txn, groups, state_filter=StateFilter.all()
+ ):
+ results = {group: {} for group in groups}
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # Temporarily disable sequential scans in this transaction. This is
+ # a temporary hack until we can add the right indices in
+ txn.execute("SET LOCAL enable_seqscan=off")
+
+ # The below query walks the state_group tree so that the "state"
+ # table includes all state_groups in the tree. It then joins
+ # against `state_groups_state` to fetch the latest state.
+ # It assumes that previous state groups are always numerically
+ # lesser.
+ # The PARTITION is used to get the event_id in the greatest state
+ # group for the given type, state_key.
+ # This may return multiple rows per (type, state_key), but last_value
+ # should be the same.
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+ PARTITION BY type, state_key ORDER BY state_group ASC
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS event_id FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM state
+ )
+ """
+
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
+
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
+ else:
+ max_entries_returned = state_filter.max_entries_returned()
+
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ for group in groups:
+ next_group = group
+
+ while next_group:
+ # We did this before by getting the list of group ids, and
+ # then passing that list to sqlite to get latest event for
+ # each (type, state_key). However, that was terribly slow
+ # without the right indices (which we can't add until
+ # after we finish deduping state, which requires this func)
+ args = [next_group]
+ args.extend(where_args)
+
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? " + where_clause,
+ args,
+ )
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
+ if (typ, state_key) not in results[group]
+ )
+
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
+ ):
+ break
+
+ next_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ return results
+
+
# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StateGroupWorkerStore(
+ EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
+):
"""The parts of StateGroupStore that can be called from workers.
"""
@@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
- ):
- results = {group: {} for group in groups}
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
- if isinstance(self.database_engine, PostgresEngine):
- # Temporarily disable sequential scans in this transaction. This is
- # a temporary hack until we can add the right indices in
- txn.execute("SET LOCAL enable_seqscan=off")
-
- # The below query walks the state_group tree so that the "state"
- # table includes all state_groups in the tree. It then joins
- # against `state_groups_state` to fetch the latest state.
- # It assumes that previous state groups are always numerically
- # lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
- # This may return multiple rows per (type, state_key), but last_value
- # should be the same.
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- )
- """
-
- for group in groups:
- args = [group]
- args.extend(where_args)
-
- txn.execute(sql + where_clause, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
- else:
- max_entries_returned = state_filter.max_entries_returned()
-
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- for group in groups:
- next_group = group
-
- while next_group:
- # We did this before by getting the list of group ids, and
- # then passing that list to sqlite to get latest event for
- # each (type, state_key). However, that was terribly slow
- # without the right indices (which we can't add until
- # after we finish deduping state, which requires this func)
- args = [next_group]
- args.extend(where_args)
-
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? " + where_clause,
- args,
- )
- results[group].update(
- ((typ, state_key), event_id)
- for typ, state_key, event_id in txn
- if (typ, state_key) not in results[group]
- )
-
- # If the number of entries in the (type,state_key)->event_id dict
- # matches the number of (type,state_keys) types we were searching
- # for, then we must have found them all, so no need to go walk
- # further down the tree... UNLESS our types filter contained
- # wildcards (i.e. Nones) in which case we have to do an exhaustive
- # search
- if (
- max_entries_returned is not None
- and len(results[group]) == max_entries_returned
- ):
- break
-
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- return results
-
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
@@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return self.runInteraction("store_state_group", _store_state_group_txn)
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
-
-class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
- """ Keeps track of the state at a given event.
-
- This is done by the concept of `state groups`. Every event is a assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
- Hence, every change in the current state causes a new state group to be
- generated. However, if no change happens (e.g., if we get a message event
- with only one parent it inherits the state group from its parent.)
-
- There are three tables:
- * `state_groups`: Stores group name, first event with in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
- """
+class StateBackgroundUpdateStore(
+ StateGroupBackgroundUpdateStore, BackgroundUpdateStore
+):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
@@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
+ super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
@@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_group"],
)
- def _store_event_state_mappings_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.prev_group
- continue
-
- state_groups[event.event_id] = context.state_group
-
- self._simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
- ],
- )
-
- for event_id, state_group_id in iteritems(state_groups):
- txn.call_after(
- self._get_state_group_for_event.prefill, (event_id,), state_group_id
- )
-
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
@@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1
+
+
+class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent), it inherits the state group from its parent.
+
+    There are three tables:
+    * `state_groups`: Stores group name, first event within the group and
+    room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ def __init__(self, db_conn, hs):
+ super(StateStore, self).__init__(db_conn, hs)
+
+ def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
+ self._simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {"state_group": state_group_id, "event_id": event_id}
+ for event_id, state_group_id in iteritems(state_groups)
+ ],
+ )
+
+ for event_id, state_group_id in iteritems(state_groups):
+ txn.call_after(
+ self._get_state_group_for_event.prefill, (event_id,), state_group_id
+ )
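
Both `_count_state_group_hops_txn` and the SQLite branch of `_get_state_groups_from_groups_txn` walk the `state_group_edges` parent chain one row at a time, while Postgres gets a recursive CTE. A self-contained sqlite3 sketch of the loop-based hop count, using the same table shape:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT);
    -- A chain of deltas: 3 -> 2 -> 1 (1 is a full snapshot with no parent).
    INSERT INTO state_group_edges VALUES (3, 2), (2, 1);
    """
)


def count_state_group_hops(conn, state_group):
    """Follow prev_state_group links until we hit a group with no parent,
    mirroring the SQLite branch of _count_state_group_hops_txn."""
    count = 0
    next_group = state_group
    while next_group is not None:
        row = conn.execute(
            "SELECT prev_state_group FROM state_group_edges WHERE state_group = ?",
            (next_group,),
        ).fetchone()
        next_group = row[0] if row else None
        if next_group is not None:
            count += 1
    return count


print(count_state_group_hops(conn, 3))  # 2
```
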
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 5fdb442104..28f33ec18f 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id):
+ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore):
Args:
prev_stream_id (int): point to get changes since (exclusive)
+ max_stream_id (int): the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
Returns:
- Deferred[list[dict]]: results
+            Deferred[tuple[int, list[dict]]]: A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
+
+ # check we're not going backwards
+ assert prev_stream_id <= max_stream_id
+
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
- return []
+            # if the current state deltas haven't changed between
+            # prev_stream_id and now, we know for certain that they haven't
+            # changed between prev_stream_id and max_stream_id.
+ return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
- WHERE stream_id > ?
+ WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
- txn.execute(sql, (prev_stream_id,))
+ txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0
- max_stream_id = prev_stream_id
- for max_stream_id, count in txn:
+
+ for stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
+ logger.debug(
+ "Clipping current_state_delta_stream rows to stream_id %i",
+ stream_id,
+ )
+ clipped_stream_id = stream_id
break
+ else:
+ # if there's no problem, we may as well go right up to the max_stream_id
+ clipped_stream_id = max_stream_id
# Now actually get the deltas
sql = """
@@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
- txn.execute(sql, (prev_stream_id, max_stream_id))
- return self.cursor_to_dict(txn)
+ txn.execute(sql, (prev_stream_id, clipped_stream_id))
+ return clipped_stream_id, self.cursor_to_dict(txn)
return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
index 09190d684e..7c224cd3d9 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/stats.py
@@ -332,6 +332,9 @@ class StatsStore(StateDeltasStore):
def _bulk_update_stats_delta_txn(txn):
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
+ logger.info(
+ "Updating %s stats for %s: %s", stats_type, stats_id, fields
+ )
self._update_stats_delta_txn(
txn,
ts=ts,
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index b3c3bf55bc..289c117396 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -165,7 +165,7 @@ class TransactionStore(SQLBaseStore):
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("destination", "retry_last_ts", "retry_interval"),
+ retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -174,12 +174,15 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
+ def set_destination_retry_timings(
+ self, destination, failure_ts, retry_last_ts, retry_interval
+ ):
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
Args:
destination (str)
+ failure_ts (int|None) - when the server started failing (ms since epoch)
retry_last_ts (int) - time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms
"""
@@ -189,12 +192,13 @@ class TransactionStore(SQLBaseStore):
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
+ failure_ts,
retry_last_ts,
retry_interval,
)
def _set_destination_retry_timings(
- self, txn, destination, retry_last_ts, retry_interval
+ self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
if self.database_engine.can_native_upsert:
@@ -202,9 +206,12 @@ class TransactionStore(SQLBaseStore):
# resetting it) or greater than the existing retry interval.
sql = """
- INSERT INTO destinations (destination, retry_last_ts, retry_interval)
- VALUES (?, ?, ?)
+ INSERT INTO destinations (
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ VALUES (?, ?, ?, ?)
ON CONFLICT (destination) DO UPDATE SET
+ failure_ts = EXCLUDED.failure_ts,
retry_last_ts = EXCLUDED.retry_last_ts,
retry_interval = EXCLUDED.retry_interval
WHERE
@@ -212,7 +219,7 @@ class TransactionStore(SQLBaseStore):
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
- txn.execute(sql, (destination, retry_last_ts, retry_interval))
+ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
return
@@ -225,7 +232,7 @@ class TransactionStore(SQLBaseStore):
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("retry_last_ts", "retry_interval"),
+ retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -235,6 +242,7 @@ class TransactionStore(SQLBaseStore):
table="destinations",
values={
"destination": destination,
+ "failure_ts": failure_ts,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
@@ -245,31 +253,12 @@ class TransactionStore(SQLBaseStore):
"destinations",
keyvalues={"destination": destination},
updatevalues={
+ "failure_ts": failure_ts,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
)
- def get_destinations_needing_retry(self):
- """Get all destinations which are due a retry for sending a transaction.
-
- Returns:
- list: A list of dicts
- """
-
- return self.runInteraction(
- "get_destinations_needing_retry", self._get_destinations_needing_retry
- )
-
- def _get_destinations_needing_retry(self, txn):
- query = (
- "SELECT * FROM destinations"
- " WHERE retry_last_ts > 0 and retry_next_ts < ?"
- )
-
- txn.execute(query, (self._clock.time_msec(),))
- return self.cursor_to_dict(txn)
-
def _start_cleanup_transactions(self):
return run_as_background_process(
"cleanup_transactions", self._cleanup_transactions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index b5188d9bee..1b1e4751b9 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -32,14 +32,14 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory"
-class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
+class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs):
- super(UserDirectoryStore, self).__init__(db_conn, hs)
+ super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.server_name = hs.hostname
@@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def remove_from_user_dir(self, user_id):
- def _remove_from_user_dir_txn(txn):
- self._simple_delete_txn(
- txn, table="user_directory", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn, table="user_directory_search", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn,
- table="users_who_share_private_rooms",
- keyvalues={"user_id": user_id},
- )
- self._simple_delete_txn(
- txn,
- table="users_who_share_private_rooms",
- keyvalues={"other_user_id": user_id},
- )
- txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
-
- return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
-
- @defer.inlineCallbacks
- def get_users_in_dir_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory because they're
- in the given room_id
- """
- user_ids_share_pub = yield self._simple_select_onecol(
- table="users_in_public_rooms",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids_share_priv = yield self._simple_select_onecol(
- table="users_who_share_private_rooms",
- keyvalues={"room_id": room_id},
- retcol="other_user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids = set(user_ids_share_pub)
- user_ids.update(user_ids_share_priv)
-
- return user_ids
-
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
@@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_public_rooms")
+ txn.execute("DELETE FROM users_who_share_private_rooms")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+
+ return self.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self._simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("display_name", "avatar_url"),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+
+class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
+
+ # How many records do we calculate before sending it to
+ # add_users_who_share_private_rooms?
+ SHARE_PRIVATE_WORKING_SET = 500
+
+ def __init__(self, db_conn, hs):
+ super(UserDirectoryStore, self).__init__(db_conn, hs)
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self._simple_delete_txn(
+ txn, table="user_directory", keyvalues={"user_id": user_id}
+ )
+ self._simple_delete_txn(
+ txn, table="user_directory_search", keyvalues={"user_id": user_id}
+ )
+ self._simple_delete_txn(
+ txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id},
+ )
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+
+ @defer.inlineCallbacks
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory because they're
+ in the given room_id
+ """
+ user_ids_share_pub = yield self._simple_select_onecol(
+ table="users_in_public_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share_priv = yield self._simple_select_onecol(
+ table="users_who_share_private_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="other_user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_share_pub)
+ user_ids.update(user_ids_share_priv)
+
+ return user_ids
+
def remove_user_who_share_room(self, user_id, room_id):
"""
Deletes entries in the users_who_share_*_rooms table. The first
@@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return [room_id for room_id, in rows]
- def delete_all_from_user_dir(self):
- """Delete the entire user directory
- """
-
- def _delete_all_from_user_dir_txn(txn):
- txn.execute("DELETE FROM user_directory")
- txn.execute("DELETE FROM user_directory_search")
- txn.execute("DELETE FROM users_in_public_rooms")
- txn.execute("DELETE FROM users_who_share_private_rooms")
- txn.call_after(self.get_user_in_directory.invalidate_all)
-
- return self.runInteraction(
- "delete_all_from_user_dir", _delete_all_from_user_dir_txn
- )
-
- @cached()
- def get_user_in_directory(self, user_id):
- return self._simple_select_one(
- table="user_directory",
- keyvalues={"user_id": user_id},
- retcols=("display_name", "avatar_url"),
- allow_none=True,
- desc="get_user_in_directory",
- )
-
def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol(
table="user_directory_stream_pos",
@@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
- def update_user_directory_stream_pos(self, stream_id):
- return self._simple_update_one(
- table="user_directory_stream_pos",
- keyvalues={},
- updatevalues={"stream_id": stream_id},
- desc="update_user_directory_stream_pos",
- )
-
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
index 05cabc2282..aa4f0da5f0 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/user_erasure_store.py
@@ -56,15 +56,15 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- def _get_erased_users(txn):
- txn.execute(
- "SELECT user_id FROM erased_users WHERE user_id IN (%s)"
- % (",".join("?" * len(user_ids))),
- user_ids,
- )
- return set(r[0] for r in txn)
-
- erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
+ rows = yield self._simple_select_many_batch(
+ table="erased_users",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="are_users_erased",
+ )
+ erased_users = set(row["user_id"] for row in rows)
+
res = dict((u, u in erased_users) for u in user_ids)
return res
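`_simple_select_many_batch` is Synapse's internal helper for batched `IN (...)` selects; as a rough standalone sketch of what the replaced hand-rolled transaction did (assuming a SQLite connection and an `erased_users` table):

    import sqlite3

    def are_users_erased_sketch(conn: sqlite3.Connection, user_ids):
        # De-duplicate, then issue one IN (...) query with a placeholder per
        # id, mirroring the raw SQL removed above.
        user_ids = tuple(set(user_ids))
        if not user_ids:
            return {}
        placeholders = ",".join("?" * len(user_ids))
        rows = conn.execute(
            "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % placeholders,
            user_ids,
        ).fetchall()
        erased = set(row[0] for row in rows)
        return {u: u in erased for u in user_ids}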
diff --git a/synapse/types.py b/synapse/types.py
index 00bb0743ff..aafc3ffe74 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -321,6 +321,7 @@ class StreamToken(
)
):
_SEPARATOR = "_"
+ START = None # type: StreamToken
@classmethod
def from_string(cls, string):
@@ -405,7 +406,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after.
"""
- __slots__ = []
+ __slots__ = [] # type: list
@classmethod
def parse(cls, string):
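The `START = None  # type: StreamToken` declaration tells the type checker the attribute's eventual type; the real value is assigned after the class body runs, elsewhere in types.py. A sketch of the pattern, with hypothetical names:

    class Token:
        START = None  # type: Token

        def __init__(self, pos: int) -> None:
            self.pos = pos

    # The class-level attribute is filled in once the class exists:
    Token.START = Token(0)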
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f1c46836b1..804dbca443 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
@@ -213,7 +217,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -340,10 +346,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -479,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
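`maybe_awaitable` lets calling code `await` a value regardless of whether it is already awaitable, which is useful when invoking callbacks that may be synchronous or asynchronous. A usage sketch (the `hook` callable is hypothetical):

    from synapse.util.async_helpers import maybe_awaitable

    async def run_hook(hook, *args):
        # `hook` may return a coroutine or a Twisted Deferred (both define
        # __await__ and pass through unchanged) or a plain value, which
        # maybe_awaitable wraps in a DoneAwaitable.
        return await maybe_awaitable(hook(*args))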
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index b50e3503f0..43fd65d693 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,6 +16,7 @@
import logging
import os
+from typing import Dict
import six
from six.moves import intern
@@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
-collectors_by_name = {}
+collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 43f66ec4be..5ac2530a6a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,10 +18,12 @@ import inspect
import logging
import threading
from collections import namedtuple
+from typing import Any, cast
from six import itervalues
from prometheus_client import Gauge
+from typing_extensions import Protocol
from twisted.internet import defer
@@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__)
+class _CachedFunction(Protocol):
+ invalidate = None # type: Any
+ invalidate_all = None # type: Any
+ invalidate_many = None # type: Any
+ prefill = None # type: Any
+ cache = None # type: Any
+ num_args = None # type: Any
+
+ def __name__(self):
+ ...
+
+
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
@@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ def __init__(
+ self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+ ):
self.orig = orig
if inlineCallbacks:
@@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer)
+ wrapped = cast(_CachedFunction, _wrapped)
+
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
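The `_CachedFunction` protocol exists because mypy will not allow attributes like `invalidate` or `prefill` to be set on a plain function type; casting the wrapper to the protocol first makes the attribute assignments type-check. A condensed standalone sketch:

    from typing import Any, Callable, cast

    from typing_extensions import Protocol

    class _Stamped(Protocol):
        extra = None  # type: Any

        def __call__(self, *args: Any) -> Any:
            ...

    def stamp(func: Callable) -> _Stamped:
        # Setting an attribute on a function works at runtime; the cast is
        # purely to satisfy the type checker.
        stamped = cast(_Stamped, func)
        stamped.extra = "metadata"
        return stamped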
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 9a72218d85..2ea4e4e911 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
from six import itervalues
SENTINEL = object()
@@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self):
self.size = 0
- self.root = {}
+ self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)
diff --git a/synapse/util/hash.py b/synapse/util/hash.py
new file mode 100644
index 0000000000..359168704e
--- /dev/null
+++ b/synapse/util/hash.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+
+import unpaddedbase64
+
+
+def sha256_and_url_safe_base64(input_text):
+ """SHA256 hash an input string, encode the digest as url-safe base64, and
+ return
+
+ :param input_text: string to hash
+ :type input_text: str
+
+ :returns a sha256 hashed and url-safe base64 encoded digest
+ :rtype: str
+ """
+ digest = hashlib.sha256(input_text.encode()).digest()
+ return unpaddedbase64.encode_base64(digest, urlsafe=True)
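For reference, `unpaddedbase64.encode_base64(digest, urlsafe=True)` produces url-safe base64 with the trailing `=` padding removed; a dependency-free sketch of the same transformation using only the standard library:

    import base64
    import hashlib

    def sha256_and_url_safe_base64_sketch(input_text: str) -> str:
        digest = hashlib.sha256(input_text.encode()).digest()
        # urlsafe_b64encode plus stripping the '=' padding matches what
        # unpaddedbase64.encode_base64(digest, urlsafe=True) returns.
        return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")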
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 0910930c21..4b1bcdf23c 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -60,12 +60,14 @@ in_flight = InFlightGauge(
)
-def measure_func(name):
+def measure_func(name=None):
def wrapper(func):
+ block_name = func.__name__ if name is None else name
+
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
- with Measure(self.clock, name):
+ with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
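With `name` now optional, decorated methods fall back to their own name for the `Measure` block. A usage sketch (hypothetical handler; `Measure` requires the instance to expose `self.clock`):

    from synapse.util.metrics import measure_func

    class ExampleHandler(object):
        def __init__(self, hs):
            self.clock = hs.get_clock()

        @measure_func()  # no name given: the Measure block is named "do_work"
        def do_work(self):
            # measure_func wraps this in defer.inlineCallbacks, so the body is
            # a generator; _compute is a hypothetical deferred-returning helper.
            result = yield self._compute()
            return result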
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 522acd5aa8..2705cbe5f8 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@
# limitations under the License.
import importlib
+import importlib.util
from synapse.config._base import ConfigError
def load_module(provider):
- """ Loads a module with its config
+ """ Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
@@ -38,3 +39,20 @@ def load_module(provider):
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config
+
+
+def load_python_module(location: str):
+ """Load a python module, and return a reference to its global namespace
+
+ Args:
+ location (str): path to the module
+
+ Returns:
+ python module object
+ """
+ spec = importlib.util.spec_from_file_location(location, location)
+ if spec is None:
+ raise Exception("Unable to load module at %s" % (location,))
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod) # type: ignore
+ return mod
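A usage sketch of `load_python_module` (the path and the looked-up attribute are hypothetical):

    # Load a module from an arbitrary file path and pull a callable out of
    # its global namespace.
    mod = load_python_module("/path/to/custom_rules.py")
    run = getattr(mod, "run", None)
    if run is not None:
        run()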
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ global _already_patched
+
+ orig_inline_callbacks = defer.inlineCallbacks
+ if _already_patched:
+ return
+
+ def new_inline_callbacks(f):
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+ changes = [] # type: List[str]
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "%s changed context from %s to %s on exception" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "Completed %s changed context from %s to %s" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (f, LoggingContext.current_context(), start_context)
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
+ _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+ """Wraps a generator that is about to be passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+
+ It's perfectly valid for log contexts to change within a function, e.g. due
+ to new Measure blocks, so such changes are added to the given `changes`
+ list instead of triggering an exception.
+
+ Args:
+ f: generator function to wrap
+ changes: A list of strings detailing how the contexts
+ changed within a function.
+
+ Returns:
+ function
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ @functools.wraps(f)
+ def check_yield_points_inner(*args, **kwargs):
+ gen = f(*args, **kwargs)
+
+ last_yield_line_no = gen.gi_frame.f_lineno
+ result = None # type: Any
+ while True:
+ expected_context = LoggingContext.current_context()
+
+ try:
+ isFailure = isinstance(result, Failure)
+ if isFailure:
+ d = result.throwExceptionIntoGenerator(gen)
+ else:
+ d = gen.send(result)
+ except (StopIteration, defer._DefGen_Return) as e:
+ if LoggingContext.current_context() != expected_context:
+                # This happens when the context is lost sometime *after* the
+                # final yield but before returning, e.g. because we forgot to
+                # yield on a function that returns a deferred.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
+ % (
+ f.__qualname__,
+ expected_context,
+ LoggingContext.current_context(),
+ f.__code__.co_filename,
+ last_yield_line_no,
+ )
+ )
+ changes.append(err)
+ return getattr(e, "value", None)
+
+ frame = gen.gi_frame
+
+ if isinstance(d, defer.Deferred) and not d.called:
+ # This happens if we yield on a deferred that doesn't follow
+ # the log context rules without wrapping in a `make_deferred_yieldable`.
+ # We raise here as this should never happen.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ err = (
+ "%s yielded with context %s rather than sentinel,"
+ " yielded on line %d in %s"
+ % (
+ frame.f_code.co_name,
+ LoggingContext.current_context(),
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ raise Exception(err)
+
+ try:
+ result = yield d
+ except Exception as e:
+ result = Failure(e)
+
+ if LoggingContext.current_context() != expected_context:
+
+            # This happens because the context is lost sometime *after* the
+            # previous yield and *before* the current one resumed. E.g. the
+            # deferred we waited on didn't follow the rules, or we forgot to
+            # yield on a function between the two yield points.
+ #
+            # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "%s changed context from %s to %s, happened between lines %d and %d in %s"
+ % (
+ frame.f_code.co_name,
+ expected_context,
+ LoggingContext.current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ changes.append(err)
+
+ last_yield_line_no = frame.f_lineno
+
+ return check_yield_points_inner
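`do_patch` is idempotent (guarded by `_already_patched`) and is meant to be called once at startup in test runs; afterwards, any `@defer.inlineCallbacks` function that finishes in a different logcontext than it started in is reported rather than silently corrupting subsequent log lines. A minimal activation sketch:

    from twisted.internet import defer

    from synapse.util.patch_inline_callbacks import do_patch

    do_patch()  # safe to call repeatedly; only the first call patches

    @defer.inlineCallbacks
    def fetch(d):
        # Each yield point in this function is now checked: benign context
        # changes are recorded, and real violations raise with the file/line
        # span between the offending yield points.
        result = yield d
        return result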
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 0862b5ca5a..af69587196 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -22,6 +22,15 @@ from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
+# the initial backoff, after the first transaction fails
+MIN_RETRY_INTERVAL = 10 * 60 * 1000
+
+# how much we multiply the backoff by after each subsequent fail
+RETRY_MULTIPLIER = 5
+
+# a cap on the backoff. (Essentially none)
+MAX_RETRY_INTERVAL = 2 ** 62
+
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
@@ -71,11 +80,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# We aren't ready to retry that destination.
raise
"""
+ failure_ts = None
retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
+ failure_ts = retry_timings["failure_ts"]
retry_last_ts, retry_interval = (
retry_timings["retry_last_ts"],
retry_timings["retry_interval"],
@@ -99,6 +110,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
destination,
clock,
store,
+ failure_ts,
retry_interval,
backoff_on_failure=backoff_on_failure,
**kwargs
@@ -111,10 +123,8 @@ class RetryDestinationLimiter(object):
destination,
clock,
store,
+ failure_ts,
retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5,
backoff_on_404=False,
backoff_on_failure=True,
):
@@ -127,15 +137,11 @@ class RetryDestinationLimiter(object):
destination (str)
clock (Clock)
store (DataStore)
+            failure_ts (int|None): when this destination started failing (in ms since
+                the epoch), or None if the last request was successful
retry_interval (int): The next retry interval taken from the
database in milliseconds, or zero if the last request was
successful.
- min_retry_interval (int): The minimum retry interval to use after
- a failed request, in milliseconds.
- max_retry_interval (int): The maximum retry interval to use after
- a failed request, in milliseconds.
- multiplier_retry_interval (int): The multiplier to use to increase
- the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the
@@ -145,10 +151,8 @@ class RetryDestinationLimiter(object):
self.store = store
self.destination = destination
+ self.failure_ts = failure_ts
self.retry_interval = retry_interval
- self.min_retry_interval = min_retry_interval
- self.max_retry_interval = max_retry_interval
- self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
@@ -189,6 +193,7 @@ class RetryDestinationLimiter(object):
logger.debug(
"Connection to %s was successful; clearing backoff", self.destination
)
+ self.failure_ts = None
retry_last_ts = 0
self.retry_interval = 0
elif not self.backoff_on_failure:
@@ -196,13 +201,14 @@ class RetryDestinationLimiter(object):
else:
# We couldn't connect.
if self.retry_interval:
- self.retry_interval *= self.multiplier_retry_interval
- self.retry_interval *= int(random.uniform(0.8, 1.4))
+ self.retry_interval = int(
+ self.retry_interval * RETRY_MULTIPLIER * random.uniform(0.8, 1.4)
+ )
- if self.retry_interval >= self.max_retry_interval:
- self.retry_interval = self.max_retry_interval
+ if self.retry_interval >= MAX_RETRY_INTERVAL:
+ self.retry_interval = MAX_RETRY_INTERVAL
else:
- self.retry_interval = self.min_retry_interval
+ self.retry_interval = MIN_RETRY_INTERVAL
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
@@ -213,11 +219,17 @@ class RetryDestinationLimiter(object):
)
retry_last_ts = int(self.clock.time_msec())
+ if self.failure_ts is None:
+ self.failure_ts = retry_last_ts
+
@defer.inlineCallbacks
def store_retry_timings():
try:
yield self.store.set_destination_retry_timings(
- self.destination, retry_last_ts, self.retry_interval
+ self.destination,
+ self.failure_ts,
+ retry_last_ts,
+ self.retry_interval,
)
except Exception:
logger.exception("Failed to store destination_retry_timings")