diff --git a/synapse/__init__.py b/synapse/__init__.py
index 8c3d7a210a..ef8853bd24 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.25.1"
+__version__ = "0.26.0"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 79b35b3e7c..aa15f73f36 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -46,6 +46,7 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
+ THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
@@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
pass
+class FederationDeniedError(SynapseError):
+ """An error raised when the server tries to federate with a server which
+ is not on its federation whitelist.
+
+ Attributes:
+ destination (str): The destination which has been denied
+ """
+
+ def __init__(self, destination):
+ """Raised by federation client or server to indicate that we are
+ are deliberately not attempting to contact a given server because it is
+ not on our federation whitelist.
+
+ Args:
+ destination (str): the domain in question
+ """
+
+ self.destination = destination
+
+ super(FederationDeniedError, self).__init__(
+ code=403,
+ msg="Federation denied with %s." % (self.destination,),
+ errcode=Codes.FORBIDDEN,
+ )
+
+
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 7d0c2879ae..c6fe4516d1 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -49,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dc3f6efd43..3b3352798d 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index a072291e1f..4de43c41f0 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 09e9488f06..f760826d27 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index ae531c0aa4..e32ee8fe93 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 92ab3b311b..cb82a415a6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(config_options):
"""
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index eab1597aaa..1ed1ca8772 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 7fbbb0b0e1..32ccea3f13 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -81,19 +81,6 @@ class PusherSlaveStore(
class PusherServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 0abba3016e..f87531f1b6 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 3bd7ef7bba..0f0ddfa78a 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -184,6 +184,9 @@ def main():
worker_configfiles.append(worker_configfile)
if options.all_processes:
+ # To start the main synapse with -a you need to add a worker file
+ # with worker_app == "synapse.app.homeserver"
+ start_stop_synapse = False
worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir):
write(
@@ -200,11 +203,29 @@ def main():
with open(worker_configfile) as stream:
worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"]
- worker_pidfile = worker_config["worker_pid_file"]
- worker_daemonize = worker_config["worker_daemonize"]
- assert worker_daemonize, "In config %r: expected '%s' to be True" % (
- worker_configfile, "worker_daemonize")
- worker_cache_factor = worker_config.get("synctl_cache_factor")
+ if worker_app == "synapse.app.homeserver":
+ # We need to special case all of this to pick up options that may
+ # be set in the main config file or in this worker config file.
+ worker_pidfile = (
+ worker_config.get("pid_file")
+ or pidfile
+ )
+ worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
+ daemonize = worker_config.get("daemonize") or config.get("daemonize")
+ assert daemonize, "Main process must have daemonize set to true"
+
+ # The master process doesn't support using worker_* config.
+ for key in worker_config:
+ if key == "worker_app": # But we allow worker_app
+ continue
+ assert not key.startswith("worker_"), \
+ "Main process cannot use worker_* config"
+ else:
+ worker_pidfile = worker_config["worker_pid_file"]
+ worker_daemonize = worker_config["worker_daemonize"]
+ assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+ worker_configfile, "worker_daemonize")
+ worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
))
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index a48c4a2ae6..494ccb702c 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index a1d6e4d4f7..3f70039acd 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
- precise:
- format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
-- %(message)s'
+ precise:
+ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
+%(request)s - %(message)s'
filters:
- context:
- (): synapse.util.logcontext.LoggingContextFilter
- request: ""
+ context:
+ (): synapse.util.logcontext.LoggingContextFilter
+ request: ""
handlers:
- file:
- class: logging.handlers.RotatingFileHandler
- formatter: precise
- filename: ${log_file}
- maxBytes: 104857600
- backupCount: 10
- filters: [context]
- console:
- class: logging.StreamHandler
- formatter: precise
- filters: [context]
+ file:
+ class: logging.handlers.RotatingFileHandler
+ formatter: precise
+ filename: ${log_file}
+ maxBytes: 104857600
+ backupCount: 10
+ filters: [context]
+ console:
+ class: logging.StreamHandler
+ formatter: precise
+ filters: [context]
loggers:
synapse:
@@ -74,17 +74,10 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
- log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
- # Logging verbosity level. Ignored if log_config is specified.
- verbose: 0
-
- # File to write logging to. Ignored if log_config is specified.
- log_file: "%(log_file)s"
-
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
@@ -123,9 +116,10 @@ class LoggingConfig(Config):
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
+ log_file = self.abspath("homeserver.log")
with open(log_config, "wb") as log_config_file:
log_config_file.write(
- DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
+ DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
@@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
)
if log_config is None:
+ # We don't have a logfile, so fall back to the 'verbosity' param from
+ # the config or cmdline. (Note that we generate a log config for new
+ # installs, so this will be an unusual case)
level = logging.INFO
level_for_storage = logging.INFO
if config.verbosity:
@@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1:
level_for_storage = logging.DEBUG
- # FIXME: we need a logging.WARN for a -q quiet option
logger = logging.getLogger('')
logger.setLevel(level)
- logging.getLogger('synapse.storage').setLevel(level_for_storage)
+ logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
formatter = logging.Formatter(log_format)
if log_file:
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ef917fc9f2..336959094b 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
+ self.registrations_require_3pid = config.get("registrations_require_3pid", [])
+ self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -52,6 +54,23 @@ class RegistrationConfig(Config):
# Enable registration for new users.
enable_registration: False
+ # The user must provide all of the below types of 3PID when registering.
+ #
+ # registrations_require_3pid:
+ # - email
+ # - msisdn
+
+ # Mandate that users are only allowed to associate certain formats of
+ # 3PIDs with accounts on this server.
+ #
+ # allowed_local_3pids:
+ # - medium: email
+ # pattern: ".*@matrix\\.org"
+ # - medium: email
+ # pattern: ".*@vector\\.im"
+ # - medium: msisdn
+ # pattern: "\\+44"
+
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 6baa474931..25ea77738a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -16,6 +16,8 @@
from ._base import Config, ConfigError
from collections import namedtuple
+from synapse.util.module_loader import load_module
+
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
+MediaStorageProviderConfig = namedtuple(
+ "MediaStorageProviderConfig", (
+ "store_local", # Whether to store newly uploaded local files
+ "store_remote", # Whether to store newly downloaded remote files
+ "store_synchronous", # Whether to wait for successful storage for local uploads
+ ),
+)
+
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
@@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
self.media_store_path = self.ensure_directory(config["media_store_path"])
- self.backup_media_store_path = config.get("backup_media_store_path")
- if self.backup_media_store_path:
- self.backup_media_store_path = self.ensure_directory(
- self.backup_media_store_path
- )
+ backup_media_store_path = config.get("backup_media_store_path")
- self.synchronous_backup_media_store = config.get(
+ synchronous_backup_media_store = config.get(
"synchronous_backup_media_store", False
)
+ storage_providers = config.get("media_storage_providers", [])
+
+ if backup_media_store_path:
+ if storage_providers:
+ raise ConfigError(
+ "Cannot use both 'backup_media_store_path' and 'storage_providers'"
+ )
+
+ storage_providers = [{
+ "module": "file_system",
+ "store_local": True,
+ "store_synchronous": synchronous_backup_media_store,
+ "store_remote": True,
+ "config": {
+ "directory": backup_media_store_path,
+ }
+ }]
+
+ # This is a list of config that can be used to create the storage
+ # providers. The entries are tuples of (Class, class_config,
+ # MediaStorageProviderConfig), where Class is the class of the provider,
+ # the class_config the config to pass to it, and
+ # MediaStorageProviderConfig are options for StorageProviderWrapper.
+ #
+ # We don't create the storage providers here as not all workers need
+ # them to be started.
+ self.media_storage_providers = []
+
+ for provider_config in storage_providers:
+ # We special case the module "file_system" so as not to need to
+ # expose FileStorageProviderBackend
+ if provider_config["module"] == "file_system":
+ provider_config["module"] = (
+ "synapse.rest.media.v1.storage_provider"
+ ".FileStorageProviderBackend"
+ )
+
+ provider_class, parsed_config = load_module(provider_config)
+
+ wrapper_config = MediaStorageProviderConfig(
+ provider_config.get("store_local", False),
+ provider_config.get("store_remote", False),
+ provider_config.get("store_synchronous", False),
+ )
+
+ self.media_storage_providers.append(
+ (provider_class, parsed_config, wrapper_config,)
+ )
+
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
- # A secondary directory where uploaded images and attachments are
- # stored as a backup.
- # backup_media_store_path: "%(media_store)s"
-
- # Whether to wait for successful write to backup media store before
- # returning successfully.
- # synchronous_backup_media_store: false
+ # Media storage providers allow media to be stored in different
+ # locations.
+ # media_storage_providers:
+ # - module: file_system
+ # # Whether to write new local files.
+ # store_local: false
+ # # Whether to write new remote media
+ # store_remote: false
+ # # Whether to block upload requests waiting for write to this
+ # # provider to complete
+ # store_synchronous: false
+ # config:
+ # directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 436dd8a6fe..8f0b6d1f28 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -55,6 +55,17 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
+ # FIXME: federation_domain_whitelist needs sytests
+ self.federation_domain_whitelist = None
+ federation_domain_whitelist = config.get(
+ "federation_domain_whitelist", None
+ )
+ # turn the whitelist into a hash for speed of lookup
+ if federation_domain_whitelist is not None:
+ self.federation_domain_whitelist = {}
+ for domain in federation_domain_whitelist:
+ self.federation_domain_whitelist[domain] = True
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -210,6 +221,17 @@ class ServerConfig(Config):
# (except those sent by local server admins). The default is False.
# block_non_admin_invites: True
+ # Restrict federation to the following whitelist of domains.
+ # N.B. we recommend also firewalling your federation listener to limit
+ # inbound federation traffic as early as possible, rather than relying
+ # purely on this application-layer restriction. If not specified, the
+ # default is to whitelist everything.
+ #
+ # federation_domain_whitelist:
+ # - lon.example.com
+ # - nyc.example.com
+ # - syd.example.com
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4748f71c2f..29eb012ddb 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -96,7 +96,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds the fingerprint of its own certificate
- # to the list. So if federation traffic is handle directly by synapse
+ # to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c5a5a8919c..4b6884918d 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -23,6 +23,11 @@ class WorkerConfig(Config):
def read_config(self, config):
self.worker_app = config.get("worker_app")
+
+ # Canonicalise worker_app so that master always has None
+ if self.worker_app == "synapse.app.homeserver":
+ self.worker_app = None
+
self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 061ee86b16..cd5627e36a 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
- 403, "You cannot unban user &s." % (target_user_id,)
+ 403, "You cannot unban user %s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a0f5d40eb3..7918d3e442 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -16,7 +16,9 @@ import logging
from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash
+from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_request
from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer
@@ -169,3 +171,28 @@ class FederationBase(object):
)
return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+ """Construct a FrozenEvent from an event json received over federation
+
+ Args:
+ pdu_json (object): pdu as received over federation
+ outlier (bool): True to mark this event as an outlier
+
+ Returns:
+ FrozenEvent
+
+ Raises:
+ SynapseError: if the pdu is missing required fields
+ """
+ # we could probably enforce a bunch of other fields here (room_id, sender,
+ # origin, etc etc)
+ assert_params_in_request(pdu_json, ('event_id', 'type'))
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b8f02f5391..813907f7f2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,29 +14,29 @@
# limitations under the License.
+import copy
+import itertools
+import logging
+import random
+
from twisted.internet import defer
-from .federation_base import FederationBase
from synapse.api.constants import Membership
-
from synapse.api.errors import (
- CodeMessageException, HttpResponseException, SynapseError,
+ CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
)
-from synapse.util import unwrapFirstError, logcontext
+from synapse.events import builder
+from synapse.federation.federation_base import (
+ FederationBase,
+ event_from_pdu_json,
+)
+import synapse.metrics
+from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logutils import log_function
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
+from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
-import copy
-import itertools
-import logging
-import random
-
-
logger = logging.getLogger(__name__)
@@ -184,7 +184,7 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
- self.event_from_pdu_json(p, outlier=False)
+ event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
@@ -244,7 +244,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
+ event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
@@ -266,6 +266,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e.message)
+ continue
except Exception as e:
pdu_attempts[destination] = now
@@ -336,11 +339,11 @@ class FederationClient(FederationBase):
)
pdus = [
- self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
@@ -441,7 +444,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
@@ -570,12 +573,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
@@ -650,7 +653,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict)
- pdu = self.event_from_pdu_json(pdu_dict)
+ pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
@@ -740,7 +743,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -788,7 +791,7 @@ class FederationClient(FederationBase):
)
events = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content.get("events", [])
]
@@ -805,15 +808,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events)
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a2327f24b6..9849953c9b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,25 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
-from .federation_base import FederationBase
-from .units import Transaction, Edu
+import simplejson as json
+from twisted.internet import defer
+from synapse.api.errors import AuthError, FederationError, SynapseError
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import (
+ FederationBase,
+ event_from_pdu_json,
+)
+from synapse.federation.units import Edu, Transaction
+import synapse.metrics
+from synapse.types import get_domain_from_id
from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
-import synapse.metrics
-
-from synapse.api.errors import AuthError, FederationError, SynapseError
-
-from synapse.crypto.event_signing import compute_event_signature
-
-import simplejson as json
-import logging
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
@@ -172,7 +171,7 @@ class FederationServer(FederationBase):
p["age_ts"] = request_time - int(p["age"])
del p["age"]
- event = self.event_from_pdu_json(p)
+ event = event_from_pdu_json(p)
room_id = event.room_id
pdus_by_room.setdefault(room_id, []).append(event)
@@ -346,7 +345,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -354,7 +353,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -374,7 +373,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -411,7 +410,7 @@ class FederationServer(FederationBase):
"""
with (yield self._server_linearizer.queue((origin, room_id))):
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -586,15 +585,6 @@ class FederationServer(FederationBase):
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def exchange_third_party_invite(
self,
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 3e7809b04f..a141ec9953 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import HttpResponseException, FederationDeniedError
from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
+events_processed_counter = client_metrics.register_counter("events_processed")
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -205,6 +207,8 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
+ events_processed_counter.inc_by(len(events))
+
yield self.store.update_federation_out_pos(
"events", next_token
)
@@ -486,6 +490,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
+ except FederationDeniedError as e:
+ logger.info(e)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1f3ce238f6..5488e82985 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -212,6 +212,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if the remote destination
+ is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2b02b021ec..06c16ba4fa 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -81,6 +81,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
@@ -92,6 +93,12 @@ class Authenticator(object):
"signatures": {},
}
+ if (
+ self.federation_domain_whitelist is not None and
+ self.server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(self.server_name)
+
if content is not None:
json_request["content"] = content
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index feca3e4c10..3dd3fa2a27 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@@ -23,6 +24,10 @@ import logging
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+events_processed_counter = metrics.register_counter("events_processed")
+
def log_failure(failure):
logger.error(
@@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
service, event
)
+ events_processed_counter.inc_by(len(events))
+
yield self.store.set_appservice_last_pos(upper_bound)
finally:
self.is_processing = False
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 573c9db8a1..258cc345dc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from twisted.internet import defer, threads
from ._base import BaseHandler
from synapse.api.constants import LoginType
@@ -25,6 +25,7 @@ from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable
from twisted.web.client import PartialDownloadError
@@ -714,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
- result = self.validate_hash(password, password_hash)
+ result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None)
@@ -842,10 +843,13 @@ class AuthHandler(BaseHandler):
password (str): Password to hash.
Returns:
- Hashed password (str).
+ Deferred(str): Hashed password.
"""
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- bcrypt.gensalt(self.bcrypt_rounds))
+ def _do_hash():
+ return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
+
+ return make_deferred_yieldable(threads.deferToThread(_do_hash))
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -855,13 +859,17 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value.
Returns:
- Whether self.hash(password) == stored_hash (bool).
+ Deferred(bool): Whether self.hash(password) == stored_hash.
"""
- if stored_hash:
+
+ def _do_validate_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')) == stored_hash
+
+ if stored_hash:
+ return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
else:
- return False
+ return defer.succeed(False)
class MacaroonGeneartor(object):
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 2152efc692..0e83453851 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.api import errors
from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will
# eventually become consistent.
return
+ except FederationDeniedError as e:
+ logger.info(e)
+ return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..d996aa90bb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise SynapseError(400, "Not a user here")
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..9aa95f89e6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,8 +19,10 @@ import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
+from synapse.api.errors import (
+ SynapseError, CodeMessageException, FederationDeniedError,
+)
+from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
@@ -32,7 +34,7 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
@@ -70,7 +72,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
@@ -139,6 +142,10 @@ class E2eKeysHandler(object):
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
+ except FederationDeniedError as e:
+ failures[destination] = {
+ "status": 403, "message": "Federation Denied",
+ }
except Exception as e:
# include ConnectionRefused and other errors
failures[destination] = {
@@ -170,7 +177,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
- if not self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -213,7 +221,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ac70730885..677532c87b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -22,6 +22,7 @@ from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
+ FederationDeniedError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
@@ -782,6 +783,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e)
+ continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7e5d3f148d..e4d0cc8b02 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result})
else:
- result = yield self.transport_client.get_publicised_groups_for_user(
- get_domain_from_id(user_id), user_id
+ bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ get_domain_from_id(user_id), [user_id],
)
+ result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
- defer.returnValue(result)
+ defer.returnValue({"groups": result})
@defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 4bc6ef51fe..9021d4d57f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
from synapse import types
from synapse.types import UserID
from synapse.util.async import run_on_reactor
+from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler):
yield run_on_reactor()
password_hash = None
if password:
- password_hash = self.auth_handler().hash(password)
+ password_hash = yield self.auth_handler().hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating theeepidcred sid %s on id server %s",
+ logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
@@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
+ if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+ raise RegistrationError(
+ 403, "Third party identifier is not allowed"
+ )
+
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index bb40075387..dfa09141ed 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
if limit:
step = limit + 1
else:
- step = len(rooms_to_scan)
+ # step cannot be zero
+ step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = []
for i in xrange(0, len(rooms_to_scan), step):
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 44414e1dc1..e057ae54c9 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
- password_hash = self._auth_handler.hash(newpassword)
+ password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4abb479ae3..f3e4973c2e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext
import synapse.metrics
@@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError,
+ HTTPConnectionPool,
)
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss
@@ -64,13 +66,23 @@ class SimpleHttpClient(object):
"""
def __init__(self, hs):
self.hs = hs
+
+ pool = HTTPConnectionPool(reactor)
+
+ # the pusher makes lots of concurrent SSL connections to sygnal, and
+ # tends to do so in batches, so we need to allow the pool to keep lots
+ # of idle connections around.
+ pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+ pool.cachedConnectionTimeout = 2 * 60
+
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory()
+ contextFactory=hs.get_http_client_context_factory(),
+ pool=pool,
)
self.user_agent = hs.version_string
self.clock = hs.get_clock()
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index e2b99ef3bd..87639b9151 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
def eb(res, record_type):
if res.check(DNSNameError):
return []
- logger.warn("Error looking up %s for %s: %s",
- record_type, host, res, res.value)
+ logger.warn("Error looking up %s for %s: %s", record_type, host, res)
return res
# no logcontexts here, so we can safely fire these off and gatherResults
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 833496b72d..9145405cb0 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -27,7 +27,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json
from synapse.api.errors import (
- SynapseError, Codes, HttpResponseException,
+ SynapseError, Codes, HttpResponseException, FederationDeniedError,
)
from signedjson.sign import sign_json
@@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300.
+
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
"""
+ if (
+ self.hs.config.federation_domain_whitelist and
+ destination not in self.hs.config.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(destination)
+
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
@@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
if not json_data_callback:
@@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
@@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
logger.debug("get_json args: %s", args)
@@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
response = yield self._request(
@@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
encoded_args = {}
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 25466cd292..165c684d0d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -42,36 +42,70 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
-incoming_requests_counter = metrics.register_counter(
- "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+ "response_count",
labels=["method", "servlet", "tag"],
+ alternative_names=(
+ # the following are all deprecated aliases for the same metric
+ metrics.name_prefix + x for x in (
+ "_requests",
+ "_response_time:count",
+ "_response_ru_utime:count",
+ "_response_ru_stime:count",
+ "_response_db_txn_count:count",
+ "_response_db_txn_duration:count",
+ )
+ )
)
+
outgoing_responses_counter = metrics.register_counter(
"responses",
labels=["method", "code"],
)
-response_timer = metrics.register_distribution(
- "response_time",
- labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+ "response_time_seconds",
+ labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_time:total",
+ ),
)
-response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+ "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_ru_utime:total",
+ ),
)
-response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+ "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_ru_stime:total",
+ ),
)
-response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+ "response_db_txn_count", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_db_txn_count:total",
+ ),
)
-response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+ "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_db_txn_duration:total",
+ ),
)
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+ "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
_next_request_id = 0
@@ -107,6 +141,10 @@ def wrap_request_handler(request_handler, include_metrics=False):
with LoggingContext(request_id) as request_context:
with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics()
+ # we start the request metrics timer here with an initial stab
+ # at the servlet name. For most requests that name will be
+ # JsonResource (or a subclass), and JsonResource._async_render
+ # will update it once it picks a servlet.
request_metrics.start(self.clock, name=self.__class__.__name__)
request_context.request = request_id
@@ -249,12 +287,23 @@ class JsonResource(HttpServer, resource.Resource):
if not m:
continue
- # We found a match! Trigger callback and then return the
- # returned response. We pass both the request and any
- # matched groups from the regex to the callback.
+ # We found a match! First update the metrics object to indicate
+ # which servlet is handling the request.
callback = path_entry.callback
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
+
+ request_metrics.name = servlet_classname
+
+ # Now trigger the callback. If it returns a response, we send it
+ # here. If it throws an exception, that is handled by the wrapper
+ # installed by @request_handler.
+
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
@@ -265,30 +314,14 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
-
- request_metrics.name = servlet_classname
-
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+ request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- # could alternatively use request.notifyFinish() and flip a flag when
- # the Deferred fires, but since the flag is RIGHT THERE it seems like
- # a waste.
- if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
- return
-
outgoing_responses_counter.inc(request.method, str(code))
# TODO: Only enable CORS for the requests that need it.
@@ -322,7 +355,7 @@ class RequestMetrics(object):
)
return
- incoming_requests_counter.inc(request.method, self.name, tag)
+ response_count.inc(request.method, self.name, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
@@ -341,7 +374,10 @@ class RequestMetrics(object):
context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, self.name, tag
+ context.db_txn_duration_ms / 1000., request.method, self.name, tag
+ )
+ response_db_sched_duration.inc_by(
+ context.db_sched_duration_ms / 1000., request.method, self.name, tag
)
@@ -364,6 +400,15 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warn(
+ "Not sending response to request %s, already disconnected.",
+ request)
+ return
+
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index cd1492b1c3..e422c8dfae 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
+ db_txn_duration_ms = context.db_txn_duration_ms
+ db_sched_duration_ms = context.db_sched_duration_ms
except Exception:
ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
+ db_txn_count, db_txn_duration_ms = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
+ " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
- int(db_txn_duration * 1000),
+ db_sched_duration_ms,
+ db_txn_duration_ms,
int(db_txn_count),
self.sentLength,
self.code,
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2265e6e8d6..e0cfb7d08f 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
-
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
+
+ # record the amount of wallclock time spent running pending calls.
+ # This is a proxy for the actual amount of time between reactor polls,
+ # since about 25% of time is actually spent running things triggered by
+ # I/O events, but that is harder to capture without rewriting half the
+ # reactor.
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..1e783e5ff4 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -15,18 +15,38 @@
from itertools import chain
+import logging
+logger = logging.getLogger(__name__)
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
- # flatten a list-of-lists
- return list(chain.from_iterable(map(func, items)))
+
+def flatten(items):
+ """Flatten a list of lists
+
+ Args:
+ items: iterable[iterable[X]]
+
+ Returns:
+ list[X]: flattened list
+ """
+ return list(chain.from_iterable(items))
class BaseMetric(object):
+ """Base class for metrics which report a single value per label set
+ """
- def __init__(self, name, labels=[]):
- self.name = name
+ def __init__(self, name, labels=[], alternative_names=[]):
+ """
+ Args:
+ name (str): principal name for this metric
+ labels (list(str)): names of the labels which will be reported
+ for this metric
+ alternative_names (iterable(str)): list of alternative names for
+ this metric. This can be useful to provide a migration path
+ when renaming metrics.
+ """
+ self._names = [name] + list(alternative_names)
self.labels = labels # OK not to clone as we never write it
def dimension(self):
@@ -36,7 +56,7 @@ class BaseMetric(object):
return not len(self.labels)
def _render_labelvalue(self, value):
- # TODO: some kind of value escape
+ # TODO: escape backslashes, quotes and newlines
return '"%s"' % (value)
def _render_key(self, values):
@@ -47,19 +67,60 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)])
)
+ def _render_for_labels(self, label_values, value):
+ """Render this metric for a single set of labels
+
+ Args:
+ label_values (list[str]): values for each of the labels
+ value: value of the metric at with these labels
+
+ Returns:
+ iterable[str]: rendered metric
+ """
+ rendered_labels = self._render_key(label_values)
+ return (
+ "%s%s %.12g" % (name, rendered_labels, value)
+ for name in self._names
+ )
+
+ def render(self):
+ """Render this metric
+
+ Each metric is rendered as:
+
+ name{label1="val1",label2="val2"} value
+
+ https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+ Returns:
+ iterable[str]: rendered metrics
+ """
+ raise NotImplementedError()
+
class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing
- integer that counts events."""
+ value that counts events or running totals.
+
+ Example use cases for Counters:
+ - Number of requests processed
+ - Number of items that were inserted into a queue
+ - Total amount of data that a system has processed
+ Counters can only go up (and be reset when the process restarts).
+ """
def __init__(self, *args, **kwargs):
super(CounterMetric, self).__init__(*args, **kwargs)
+ # dict[list[str]]: value for each set of label values. the keys are the
+ # label values, in the same order as the labels in self.labels.
+ #
+ # (if the metric is a scalar, the (single) key is the empty list).
self.counts = {}
# Scalar metrics are never empty
if self.is_scalar():
- self.counts[()] = 0
+ self.counts[()] = 0.
def inc_by(self, incr, *values):
if len(values) != self.dimension():
@@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
def inc(self, *values):
self.inc_by(1, *values)
- def render_item(self, k):
- return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
-
def render(self):
- return map_concat(self.render_item, sorted(self.counts.keys()))
+ return flatten(
+ self._render_for_labels(k, self.counts[k])
+ for k in sorted(self.counts.keys())
+ )
class CallbackMetric(BaseMetric):
@@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
self.callback = callback
def render(self):
- value = self.callback()
+ try:
+ value = self.callback()
+ except Exception:
+ logger.exception("Failed to render %s", self.name)
+ return ["# FAILED to render " + self.name]
if self.is_scalar():
- return ["%s %.12g" % (self.name, value)]
+ return list(self._render_for_labels([], value))
- return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
- for k in sorted(value.keys())]
+ return flatten(
+ self._render_for_labels(k, value[k])
+ for k in sorted(value.keys())
+ )
class DistributionMetric(object):
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index c16f61452c..2cbac571b8 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -13,21 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from synapse.push import PusherConfigException
+import logging
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-import logging
import push_rule_evaluator
import push_tools
-
+import synapse
+from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+http_push_processed_counter = metrics.register_counter(
+ "http_pushes_processed",
+)
+
+http_push_failed_counter = metrics.register_counter(
+ "http_pushes_failed",
+)
+
class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
@@ -152,9 +161,16 @@ class HttpPusher(object):
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
+ logger.info(
+ "Processing %i unprocessed push actions for %s starting at "
+ "stream_ordering %s",
+ len(unprocessed), self.name, self.last_stream_ordering,
+ )
+
for push_action in unprocessed:
processed = yield self._process_one(push_action)
if processed:
+ http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -169,6 +185,7 @@ class HttpPusher(object):
self.failing_since
)
else:
+ http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
@@ -316,7 +333,10 @@ class HttpPusher(object):
try:
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
except Exception:
- logger.warn("Failed to push %s ", self.url)
+ logger.warn(
+ "Failed to push event %s to %s",
+ event.event_id, self.name, exc_info=True,
+ )
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
@@ -325,7 +345,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _send_badge(self, badge):
- logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+ logger.info("Sending updated badge count %d to %s", badge, self.name)
d = {
'notification': {
'id': '',
@@ -347,7 +367,10 @@ class HttpPusher(object):
try:
resp = yield self.http_client.post_json_get_json(self.url, d)
except Exception:
- logger.exception("Failed to push %s ", self.url)
+ logger.warn(
+ "Failed to send badge count to %s",
+ self.name, exc_info=True,
+ )
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d59503b905..0a9a290af4 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote")
def on_RDATA(self, cmd):
+ stream_name = cmd.stream_name
+ inbound_rdata_count.inc(stream_name)
+
try:
- row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+ row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r",
- self.id(), cmd.stream_name, cmd.row
+ self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
- self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+ self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(cmd.stream_name, [])
+ rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+ self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token)
@@ -644,3 +647,9 @@ metrics.register_callback(
},
labels=["command", "name", "conn_id"],
)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = metrics.register_counter(
+ "inbound_rdata_count",
+ labels=["stream_name"],
+)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5669ecb724..45844aa2d2 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- if 'medium' not in identifier or 'address' not in identifier:
+ address = identifier.get('address')
+ medium = identifier.get('medium')
+
+ if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- address = identifier['address']
- if identifier['medium'] == 'email':
+ if medium == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- identifier['medium'], address
+ medium, address,
)
if not user_id:
+ logger.warn(
+ "unknown 3pid identifier medium %s, address %r",
+ medium, address,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 32ed1d3ab2..5c5fa8f7ab 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
self.handlers = hs.get_handlers()
def on_GET(self, request):
+
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]}
- )
+ ])
else:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if require_email or not require_msisdn:
+ flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
- },
+ }
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.PASSWORD
}
- ]}
- )
+ ])
+ return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 75b735b47d..867ec8602c 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -195,15 +195,20 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ if 'ts' in request.args and requester.app_service:
+ event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event(
requester,
- {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- },
+ event_dict,
txn_id=txn_id,
)
@@ -487,13 +492,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(RoomEventServlet, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ event = yield self.event_handler.get_event(requester.user, event_id)
+
+ time_now = self.clock.time_msec()
+ if event:
+ defer.returnValue((200, serialize_event(event, time_now)))
+ else:
+ defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
- super(RoomEventContext, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@@ -803,4 +830,5 @@ def register_servlets(hs, http_server):
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventContext(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 385a3ad2ec..30523995af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
)
@@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index e9d88a8895..c6f4680a76 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
)
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
@@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- flows = [
- [LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.RECAPTCHA]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else:
- flows = [
- [LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.DUMMY]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email or require_msisdn:
+ flows.extend([[LoginType.MSISDN]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN],
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
])
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
+ # Check that we're not trying to register a denied 3pid.
+ #
+ # the user-facing checks will probably already have happened in
+ # /register/email/requestToken when we requested a 3pid, but that's not
+ # guaranteed.
+
+ if auth_result:
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ if not check_3pid_allowed(self.hs, medium, address):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed",
+ Codes.THREEPID_DENIED,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index cc2842aa72..17e6079cba 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ logger.debug("Federation denied with %s", server_name)
+ continue
+
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 95fa95fce3..e7ac01da01 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
- if upload_name:
- if is_ascii(upload_name):
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename=%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
- else:
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename*=utf-8''%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
- request.setHeader(
- b"Content-Length", b"%d" % (file_size,)
- )
+ add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
finish_request(request)
else:
respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+ """Adds the correct response headers in preparation for responding with the
+ media.
+
+ Args:
+ request (twisted.web.http.Request)
+ media_type (str): The media/content type.
+ file_size (int): Size in bytes of the media, if known.
+ upload_name (str): The name of the requested file, if any.
+ """
+ request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ if is_ascii(upload_name):
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+ else:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename*=utf-8''%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+
+ # cache for at least a day.
+ # XXX: we might want to turn this off for data we don't want to
+ # recommend caching as it's sensitive or private - or at least
+ # select private. don't bother setting Expires as all our
+ # clients are smart enough to be happy with Cache-Control
+ request.setHeader(
+ b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+ )
+
+ request.setHeader(
+ b"Content-Length", b"%d" % (file_size,)
+ )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+ """Responds to the request with given responder. If responder is None then
+ returns 404.
+
+ Args:
+ request (twisted.web.http.Request)
+ responder (Responder|None)
+ media_type (str): The media/content type.
+ file_size (int|None): Size in bytes of the media. If not known it should be None
+ upload_name (str|None): The name of the requested file, if any.
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ add_file_headers(request, media_type, file_size, upload_name)
+ with responder:
+ yield responder.write_to_consumer(request)
+ finish_request(request)
+
+
+class Responder(object):
+ """Represents a response that can be streamed to the requester.
+
+ Responder is a context manager which *must* be used, so that any resources
+ held can be cleaned up.
+ """
+ def write_to_consumer(self, consumer):
+ """Stream response into consumer
+
+ Args:
+ consumer (IConsumer)
+
+ Returns:
+ Deferred: Resolves once the response has finished being written
+ """
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+class FileInfo(object):
+ """Details about a requested/uploaded file.
+
+ Attributes:
+ server_name (str): The server name where the media originated from,
+ or None if local.
+ file_id (str): The local ID of the file. For local files this is the
+ same as the media_id
+ url_cache (bool): If the file is for the url preview cache
+ thumbnail (bool): Whether the file is a thumbnail or not.
+ thumbnail_width (int)
+ thumbnail_height (int)
+ thumbnail_method (str)
+ thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ """
+ def __init__(self, server_name, file_id, url_cache=False,
+ thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+ thumbnail_method=None, thumbnail_type=None):
+ self.server_name = server_name
+ self.file_id = file_id
+ self.url_cache = url_cache
+ self.thumbnail = thumbnail
+ self.thumbnail_width = thumbnail_width
+ self.thumbnail_height = thumbnail_height
+ self.thumbnail_method = thumbnail_method
+ self.thumbnail_type = thumbnail_type
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6879249c8a..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -14,7 +14,7 @@
# limitations under the License.
import synapse.http.servlet
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
- self.store = hs.get_datastore()
- self.version_string = hs.version_string
+
+ # Both of these are expected by @request_handler()
self.clock = hs.get_clock()
+ self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
@@ -57,59 +57,16 @@ class DownloadResource(Resource):
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id, name)
+ yield self.media_repo.get_local_media(request, media_id, name)
else:
- yield self._respond_remote_file(
- request, server_name, media_id, name
- )
-
- @defer.inlineCallbacks
- def _respond_local_file(self, request, media_id, name):
- media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
- respond_404(request)
- return
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_filepath(media_id)
- else:
- file_path = self.filepaths.local_media_filepath(media_id)
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
-
- @defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id, name):
- # don't forward requests for remote media if allow_remote is false
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True)
- if not allow_remote:
- logger.info(
- "Rejecting request for remote media %s/%s due to allow_remote",
- server_name, media_id,
- )
- respond_404(request)
- return
-
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- filesystem_id = media_info["filesystem_id"]
- upload_name = name if name else media_info["upload_name"]
-
- file_path = self.filepaths.remote_media_filepath(
- server_name, filesystem_id
- )
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
+ allow_remote = synapse.http.servlet.parse_boolean(
+ request, "allow_remote", default=True)
+ if not allow_remote:
+ logger.info(
+ "Rejecting request for remote media %s/%s due to allow_remote",
+ server_name, media_id,
+ )
+ respond_404(request)
+ return
+
+ yield self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index eed9056a2f..485db8577a 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
+from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
@@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer
+from .storage_provider import StorageProviderWrapper
+from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError, HttpResponseException, \
- NotFoundError
+from synapse.api.errors import (
+ SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
import os
@@ -47,7 +52,7 @@ import urlparse
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -63,96 +68,62 @@ class MediaRepository(object):
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
- self.backup_base_path = hs.config.backup_media_store_path
-
- self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
-
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
- self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
- )
-
- @defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
- self.recently_accessed_remotes = set()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
- yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
- )
+ # List of StorageProviders where we should search for media and
+ # potentially upload to.
+ storage_providers = []
- @staticmethod
- def _makedirs(filepath):
- dirname = os.path.dirname(filepath)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ backend = clz(hs, provider_config)
+ provider = StorageProviderWrapper(
+ backend,
+ store_local=wrapper_config.store_local,
+ store_remote=wrapper_config.store_remote,
+ store_synchronous=wrapper_config.store_synchronous,
+ )
+ storage_providers.append(provider)
- @staticmethod
- def _write_file_synchronously(source, fname):
- """Write `source` to the path `fname` synchronously. Should be called
- from a thread.
+ self.media_storage = MediaStorage(
+ self.primary_base_path, self.filepaths, storage_providers,
+ )
- Args:
- source: A file like object to be written
- fname (str): Path to write to
- """
- MediaRepository._makedirs(fname)
- source.seek(0) # Ensure we read from the start of the file
- with open(fname, "wb") as f:
- shutil.copyfileobj(source, f)
+ self.clock.looping_call(
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
+ )
@defer.inlineCallbacks
- def write_to_file_and_backup(self, source, path):
- """Write `source` to the on disk media store, and also the backup store
- if configured.
-
- Args:
- source: A file like object that should be written
- path (str): Relative path to write file to
-
- Returns:
- Deferred[str]: the file path written to in the primary media store
- """
- fname = os.path.join(self.primary_base_path, path)
-
- # Write to the main repository
- yield make_deferred_yieldable(threads.deferToThread(
- self._write_file_synchronously, source, fname,
- ))
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
+ self.recently_accessed_remotes = set()
- # Write to backup repository
- yield self.copy_to_backup(path)
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
- defer.returnValue(fname)
+ yield self.store.update_cached_last_access_time(
+ local_media, remote_media, self.clock.time_msec()
+ )
- @defer.inlineCallbacks
- def copy_to_backup(self, path):
- """Copy a file from the primary to backup media store, if configured.
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
Args:
- path(str): Relative path to write file to
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
"""
- if self.backup_base_path:
- primary_fname = os.path.join(self.primary_base_path, path)
- backup_fname = os.path.join(self.backup_base_path, path)
-
- # We can either wait for successful writing to the backup repository
- # or write in the background and immediately return
- if self.synchronous_backup_media_store:
- yield make_deferred_yieldable(threads.deferToThread(
- shutil.copyfile, primary_fname, backup_fname,
- ))
- else:
- preserve_fn(threads.deferToThread)(
- shutil.copyfile, primary_fname, backup_fname,
- )
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
@@ -171,10 +142,13 @@ class MediaRepository(object):
"""
media_id = random_string(24)
- fname = yield self.write_to_file_and_backup(
- content, self.filepaths.local_media_filepath_rel(media_id)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
)
+ fname = yield self.media_storage.store_file(content, file_info)
+
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
@@ -185,134 +159,275 @@ class MediaRepository(object):
media_length=content_length,
user_id=auth_user,
)
- media_info = {
- "media_type": media_type,
- "media_length": content_length,
- }
- yield self._generate_thumbnails(None, media_id, media_info)
+ yield self._generate_thumbnails(
+ None, media_id, media_id, media_type,
+ )
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
- def get_remote_media(self, server_name, media_id):
+ def get_local_media(self, request, media_id, name):
+ """Responds to reqests for local media, if exists, or returns 404.
+
+ Args:
+ request(twisted.web.http.Request)
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content.)
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info or media_info["quarantined_by"]:
+ respond_404(request)
+ return
+
+ self.mark_recently_accessed(None, media_id)
+
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ url_cache = media_info["url_cache"]
+
+ file_info = FileInfo(
+ None, media_id,
+ url_cache=url_cache,
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_media(self, request, server_name, media_id, name):
+ """Respond to requests for remote media.
+
+ Args:
+ request(twisted.web.http.Request)
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ self.mark_recently_accessed(server_name, media_id)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
- media_info = yield self._get_remote_media_impl(server_name, media_id)
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # We deliberately stream the file outside the lock
+ if responder:
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+ else:
+ respond_404(request)
+
+ @defer.inlineCallbacks
+ def get_remote_media_info(self, server_name, media_id):
+ """Gets the media info associated with the remote file, downloading
+ if necessary.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[dict]: The media_info of the file
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
+ key = (server_name, media_id)
+ with (yield self.remote_media_linearizer.queue(key)):
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # Ensure we actually use the responder so that it releases resources
+ if responder:
+ with responder:
+ pass
+
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
+ """Looks for media in local cache, if not there then attempt to
+ download from remote server.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[(Responder, media_info)]
+ """
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
- if not media_info:
- media_info = yield self._download_remote_file(
- server_name, media_id
- )
- elif media_info["quarantined_by"]:
- raise NotFoundError()
+
+ # file_id is the ID we use to track the file locally. If we've already
+ # seen the file then reuse the existing ID, otherwise genereate a new
+ # one.
+ if media_info:
+ file_id = media_info["filesystem_id"]
else:
- self.recently_accessed_remotes.add((server_name, media_id))
- yield self.store.update_cached_last_access_time(
- [(server_name, media_id)], self.clock.time_msec()
- )
- defer.returnValue(media_info)
+ file_id = random_string(24)
+
+ file_info = FileInfo(server_name, file_id)
+
+ # If we have an entry in the DB, try and look for it
+ if media_info:
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ raise NotFoundError()
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ defer.returnValue((responder, media_info))
+
+ # Failed to find the file anywhere, lets download it.
+
+ media_info = yield self._download_remote_file(
+ server_name, media_id, file_id
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ defer.returnValue((responder, media_info))
@defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id):
- file_id = random_string(24)
+ def _download_remote_file(self, server_name, media_id, file_id):
+ """Attempt to download the remote file from the given server name,
+ using the given file_id as the local id.
+
+ Args:
+ server_name (str): Originating server
+ media_id (str): The media ID of the content (as defined by the
+ remote server). This is different than the file_id, which is
+ locally generated.
+ file_id (str): Local file ID
+
+ Returns:
+ Deferred[MediaInfo]
+ """
- fpath = self.filepaths.remote_media_filepath_rel(
- server_name, file_id
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
)
- fname = os.path.join(self.primary_base_path, fpath)
- self._makedirs(fname)
- try:
- with open(fname, "wb") as f:
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ request_path = "/".join((
+ "/_matrix/media/v1/download", server_name, media_id,
+ ))
+ try:
+ length, headers = yield self.client.get_file(
+ server_name, request_path, output_stream=f,
+ max_size=self.max_upload_size, args={
+ # tell the remote server to 404 if it doesn't
+ # recognise the server_name, to make sure we don't
+ # end up with a routing loop.
+ "allow_remote": "false",
+ }
+ )
+ except twisted.internet.error.DNSLookupError as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %r",
+ server_name, media_id, e)
+ raise NotFoundError()
+
+ except HttpResponseException as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %s",
+ server_name, media_id, e.response)
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise SynapseError.from_http_response_exception(e)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise
+ except NotRetryingDestination:
+ logger.warn("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
+ except Exception:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ yield finish()
+
+ media_type = headers["Content-Type"][0]
+
+ time_now_ms = self.clock.time_msec()
+
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get("filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith("utf-8''"):
+ upload_name = upload_name_utf8[7:]
+
+ # If there isn't check for an ascii name.
+ if not upload_name:
+ upload_name_ascii = params.get("filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii
+
+ if upload_name:
+ upload_name = urlparse.unquote(upload_name)
try:
- length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size, args={
- # tell the remote server to 404 if it doesn't
- # recognise the server_name, to make sure we don't
- # end up with a routing loop.
- "allow_remote": "false",
- }
- )
- except twisted.internet.error.DNSLookupError as e:
- logger.warn("HTTP error fetching remote media %s/%s: %r",
- server_name, media_id, e)
- raise NotFoundError()
-
- except HttpResponseException as e:
- logger.warn("HTTP error fetching remote media %s/%s: %s",
- server_name, media_id, e.response)
- if e.code == twisted.web.http.NOT_FOUND:
- raise SynapseError.from_http_response_exception(e)
- raise SynapseError(502, "Failed to fetch remote media")
-
- except SynapseError:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise
- except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
- raise SynapseError(502, "Failed to fetch remote media")
- except Exception:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise SynapseError(502, "Failed to fetch remote media")
-
- yield self.copy_to_backup(fpath)
-
- media_type = headers["Content-Type"][0]
- time_now_ms = self.clock.time_msec()
-
- content_disposition = headers.get("Content-Disposition", None)
- if content_disposition:
- _, params = cgi.parse_header(content_disposition[0],)
- upload_name = None
-
- # First check if there is a valid UTF-8 filename
- upload_name_utf8 = params.get("filename*", None)
- if upload_name_utf8:
- if upload_name_utf8.lower().startswith("utf-8''"):
- upload_name = upload_name_utf8[7:]
-
- # If there isn't check for an ascii name.
- if not upload_name:
- upload_name_ascii = params.get("filename", None)
- if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii
-
- if upload_name:
- upload_name = urlparse.unquote(upload_name)
- try:
- upload_name = upload_name.decode("utf-8")
- except UnicodeDecodeError:
- upload_name = None
- else:
- upload_name = None
-
- logger.info("Stored remote media in file %r", fname)
-
- yield self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
- except Exception:
- os.remove(fname)
- raise
+ upload_name = upload_name.decode("utf-8")
+ except UnicodeDecodeError:
+ upload_name = None
+ else:
+ upload_name = None
+
+ logger.info("Stored remote media in file %r", fname)
+
+ yield self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
media_info = {
"media_type": media_type,
@@ -323,7 +438,7 @@ class MediaRepository(object):
}
yield self._generate_thumbnails(
- server_name, media_id, media_info
+ server_name, media_id, file_id, media_type,
)
defer.returnValue(media_info)
@@ -368,11 +483,18 @@ class MediaRepository(object):
if t_byte_source:
try:
- output_path = yield self.write_to_file_and_backup(
- t_byte_source,
- self.filepaths.local_media_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -400,11 +522,18 @@ class MediaRepository(object):
if t_byte_source:
try:
- output_path = yield self.write_to_file_and_backup(
- t_byte_source,
- self.filepaths.remote_media_thumbnail_rel(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -421,21 +550,22 @@ class MediaRepository(object):
defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
+ def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+ url_cache=False):
"""Generate and store thumbnails for an image.
Args:
- server_name(str|None): The server name if remote media, else None if local
- media_id(str)
- media_info(dict)
- url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
+ server_name (str|None): The server name if remote media, else None if local
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content)
+ file_id (str): Local file ID
+ media_type (str): The content type of the file
+ url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
"""
- media_type = media_info["media_type"]
- file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
@@ -472,20 +602,6 @@ class MediaRepository(object):
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
- # Work out the correct file name for thumbnail
- if server_name:
- file_path = self.filepaths.remote_media_thumbnail_rel(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- elif url_cache:
- file_path = self.filepaths.url_cache_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- file_path = self.filepaths.local_media_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
-
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -505,9 +621,19 @@ class MediaRepository(object):
continue
try:
- # Write to disk
- output_path = yield self.write_to_file_and_backup(
- t_byte_source, file_path,
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -620,7 +746,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
- self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+ self.putChild("thumbnail", ThumbnailResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
- self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+ self.putChild("preview_url", PreviewUrlResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..041ae396cd
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vecotr Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+from ._base import Responder
+
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ local_media_directory (str): Base path where we store media on disk
+ filepaths (MediaFilePaths)
+ storage_providers ([StorageProvider]): List of StorageProvider that are
+ used to fetch and store files.
+ """
+
+ def __init__(self, local_media_directory, filepaths, storage_providers):
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+
+ @defer.inlineCallbacks
+ def store_file(self, source, file_info):
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info (FileInfo): Info about the file to store
+
+ Returns:
+ Deferred[str]: the file path written to in the primary media store
+ """
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ _write_file_synchronously, source, fname,
+ ))
+
+ defer.returnValue(fname)
+
+ @contextlib.contextmanager
+ def store_into_file(self, file_info):
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+ like object that can be written to, fname is the absolute path of file
+ on disk, and finish_cb is a function that returns a Deferred.
+
+ fname can be used to read the contents from after upload, e.g. to
+ generate thumbnails.
+
+ finish_cb must be called and waited on after the file has been
+ successfully been written to. Should not be called if there was an
+ error.
+
+ Args:
+ file_info (FileInfo): Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ yield finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ finished_called = [False]
+
+ @defer.inlineCallbacks
+ def finish():
+ for provider in self.storage_providers:
+ yield provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ try:
+ with open(fname, "wb") as f:
+ yield f, fname, finish
+ except Exception:
+ t, v, tb = sys.exc_info()
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+ raise t, v, tb
+
+ if not finished_called:
+ raise Exception("Finished callback not called")
+
+ @defer.inlineCallbacks
+ def fetch_media(self, file_info):
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[Responder|None]: Returns a Responder if the file was found,
+ otherwise None.
+ """
+
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(FileResponder(open(local_path, "rb")))
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ defer.returnValue(res)
+
+ defer.returnValue(None)
+
+ def _file_info_to_path(self, file_info):
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ str
+ """
+ if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method,
+ )
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id,
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.local_media_filepath_rel(
+ file_info.file_id,
+ )
+
+
+def _write_file_synchronously(source, fname):
+ """Write `source` to the path `fname` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object to be written
+ fname (str): Path to write to
+ """
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ source.seek(0) # Ensure we read from the start of the file
+ with open(fname, "wb") as f:
+ shutil.copyfileobj(source, f)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+ open_file (file): A file like object to be streamed ot the client,
+ is closed when finished streaming.
+ """
+ def __init__(self, open_file):
+ self.open_file = open_file
+
+ @defer.inlineCallbacks
+ def write_to_consumer(self, consumer):
+ yield FileSender().beginFileTransfer(self.open_file, consumer)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 40d2e664eb..981f01e417 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
+from ._base import FileInfo
+
from synapse.api.errors import (
SynapseError, Codes,
)
@@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.auth = hs.get_auth()
@@ -62,6 +64,7 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
+ self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@@ -182,8 +185,10 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
+ file_id = media_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
- None, media_info['filesystem_id'], media_info, url_cache=True,
+ None, file_id, file_id, media_info["media_type"],
+ url_cache=True,
)
og = {
@@ -228,8 +233,10 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
+ file_id = image_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
- None, image_info['filesystem_id'], image_info, url_cache=True,
+ None, file_id, file_id, image_info["media_type"],
+ url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -273,19 +280,21 @@ class PreviewUrlResource(Resource):
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
- fpath = self.filepaths.url_cache_filepath_rel(file_id)
- fname = os.path.join(self.primary_base_path, fpath)
- self.media_repo._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=file_id,
+ url_cache=True,
+ )
try:
- with open(fname, "wb") as f:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
# FIXME: pass through 404s and other error messages nicely
- yield self.media_repo.copy_to_backup(fpath)
+ yield finish()
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
@@ -327,7 +336,6 @@ class PreviewUrlResource(Resource):
)
except Exception as e:
- os.remove(fname)
raise SynapseError(
500, ("Failed to download content: %s" % e),
Codes.UNKNOWN
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..c188192f2b
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.config._base import Config
+from synapse.util.logcontext import preserve_fn
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+ """A storage provider is a service that can store uploaded media and
+ retrieve them.
+ """
+ def store_file(self, path, file_info):
+ """Store the file described by file_info. The actual contents can be
+ retrieved by reading the file in file_info.upload_path.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred
+ """
+ pass
+
+ def fetch(self, path, file_info):
+ """Attempt to fetch the file described by file_info and stream it
+ into writer.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred(Responder): Returns a Responder if the provider has the file,
+ otherwise returns None.
+ """
+ pass
+
+
+class StorageProviderWrapper(StorageProvider):
+ """Wraps a storage provider and provides various config options
+
+ Args:
+ backend (StorageProvider)
+ store_local (bool): Whether to store new local files or not.
+ store_synchronous (bool): Whether to wait for file to be successfully
+ uploaded, or todo the upload in the backgroud.
+ store_remote (bool): Whether remote media should be uploaded
+ """
+ def __init__(self, backend, store_local, store_synchronous, store_remote):
+ self.backend = backend
+ self.store_local = store_local
+ self.store_synchronous = store_synchronous
+ self.store_remote = store_remote
+
+ def store_file(self, path, file_info):
+ if not file_info.server_name and not self.store_local:
+ return defer.succeed(None)
+
+ if file_info.server_name and not self.store_remote:
+ return defer.succeed(None)
+
+ if self.store_synchronous:
+ return self.backend.store_file(path, file_info)
+ else:
+ # TODO: Handle errors.
+ preserve_fn(self.backend.store_file)(path, file_info)
+ return defer.succeed(None)
+
+ def fetch(self, path, file_info):
+ return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+ """A storage provider that stores files in a directory on a filesystem.
+
+ Args:
+ hs (HomeServer)
+ config: The config returned by `parse_config`.
+ """
+
+ def __init__(self, hs, config):
+ self.cache_directory = hs.config.media_store_path
+ self.base_directory = config
+
+ def store_file(self, path, file_info):
+ """See StorageProvider.store_file"""
+
+ primary_fname = os.path.join(self.cache_directory, path)
+ backup_fname = os.path.join(self.base_directory, path)
+
+ dirname = os.path.dirname(backup_fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ return threads.deferToThread(
+ shutil.copyfile, primary_fname, backup_fname,
+ )
+
+ def fetch(self, path, file_info):
+ """See StorageProvider.fetch"""
+
+ backup_fname = os.path.join(self.base_directory, path)
+ if os.path.isfile(backup_fname):
+ return FileResponder(open(backup_fname, "rb"))
+
+ @staticmethod
+ def parse_config(config):
+ """Called on startup to parse config supplied. This should parse
+ the config and raise if there is a problem.
+
+ The returned value is passed into the constructor.
+
+ In this case we only care about a single param, the directory, so let's
+ just pull that out.
+ """
+ return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 68d56b2b10..12e84a2b7c 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
# limitations under the License.
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+ parse_media_id, respond_404, respond_with_file, FileInfo,
+ respond_with_responder,
+)
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.store = hs.get_datastore()
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
+ self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
-
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path)
- else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
)
+ t_type = file_info.thumbnail_type
+ t_length = thumbnail_info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
+ else:
+ logger.info("Couldn't find any generated thumbnails")
+ respond_404(request)
+
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
desired_height, desired_method,
desired_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
@@ -141,22 +142,25 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- yield respond_with_file(request, desired_type, file_path)
- return
-
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
@@ -166,21 +170,14 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -195,14 +192,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, desired_width, desired_height,
- desired_type, desired_method,
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
)
- yield respond_with_file(request, desired_type, file_path)
- return
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -213,22 +220,16 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
- # We should proxy the thumbnail from the remote server instead.
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ # We should proxy the thumbnail from the remote server instead of
+ # downloading the remote file and generating our own thumbnails.
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -238,59 +239,23 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- file_id = thumbnail_info["filesystem_id"]
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
- )
-
- @defer.inlineCallbacks
- def _respond_default_thumbnail(self, request, media_info, width, height,
- method, m_type):
- # XXX: how is this meant to work? store.get_default_thumbnails
- # appears to always return [] so won't this always 404?
- media_type = media_info["media_type"]
- top_level_type = media_type.split("/")[0]
- sub_type = media_type.split("/")[-1].split(";")[0]
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, sub_type,
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, "_default",
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- "_default", "_default",
- )
- if not thumbnail_infos:
+ logger.info("Failed to find any generated thumbnails")
respond_404(request)
- return
-
- thumbnail_info = self._select_thumbnail(
- width, height, "crop", m_type, thumbnail_infos
- )
-
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- t_length = thumbnail_info["thumbnail_length"]
-
- file_path = self.filepaths.default_thumbnail(
- top_level_type, sub_type, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
diff --git a/synapse/server.py b/synapse/server.py
index 99693071b6..ff8a8fbc46 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -307,6 +307,23 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
+ def get_db_conn(self, run_new_connection=True):
+ """Makes a new connection to the database, skipping the db pool
+
+ Returns:
+ Connection: a connection object implementing the PEP-249 spec
+ """
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..4c8247e7c2 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -146,8 +146,20 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
- def get_current_state_ids(self, room_id, event_type=None, state_key="",
- latest_event_ids=None):
+ def get_current_state_ids(self, room_id, latest_event_ids=None):
+ """Get the current state, or the state at a set of events, for a room
+
+ Args:
+ room_id (str):
+
+ latest_event_ids (iterable[str]|None): if given, the forward
+ extremities to resolve. If None, we look them up from the
+ database (via a cache)
+
+ Returns:
+ Deferred[dict[(str, str), str)]]: the state dict, mapping from
+ (event_type, state_key) -> event_id
+ """
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
@@ -155,10 +167,6 @@ class StateHandler(object):
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
- if event_type:
- defer.returnValue(state.get((event_type, state_key)))
- return
-
defer.returnValue(state)
@defer.inlineCallbacks
@@ -341,7 +349,7 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events(
+ new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
@@ -404,7 +412,7 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events(state_set_ids, state_map)
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
@@ -420,19 +428,17 @@ def _ordered_events(events):
return sorted(events, key=key_func)
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- state_map_factory(dict|callable): If callable, then will be called
- with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event. Otherwise, should be
- a dict from event_id to event of all events in state_sets.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent] is a map from
- (type, state_key) to event.
+ dict[(str, str), synapse.events.FrozenEvent]:
+ a map from (type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -441,13 +447,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
- if callable(state_map_factory):
- return _resolve_with_state_fac(
- unconflicted_state, conflicted_state, state_map_factory
- )
-
- state_map = state_map_factory
-
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -491,8 +490,26 @@ def _seperate(state_sets):
@defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
- state_map_factory):
+def resolve_events_with_factory(state_sets, state_map_factory):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+ state_map_factory(func): will be called
+ with a list of event_ids that are needed, and should return with
+ a Deferred of dict of event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
+ a map from (type, state_key) to event.
+ """
+ if len(state_sets) == 1:
+ defer.returnValue(state_sets[0])
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..68125006eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -291,33 +291,33 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Starts a transaction on the database and runs a given function
- start_time = time.time() * 1000
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ current_context = LoggingContext.current_context()
after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ return self._new_transaction(
+ conn, desc, after_callbacks, final_callbacks, current_context,
+ func, *args, **kwargs
+ )
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
@@ -329,14 +329,27 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
+ sql_scheduling_timer.inc_by(sched_duration_ms)
+ current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index d08f7571d7..33fccfa7a8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
+from synapse.state import resolve_events_with_factory
from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id
@@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred())
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@@ -146,18 +146,25 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
+ # handle_queue_loop runs in the sentinel logcontext, so
+ # there is no need to preserve_fn when running the
+ # callbacks on the deferred.
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
- except Exception as e:
- item.deferred.errback(e)
+ except Exception:
+ item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- preserve_fn(handle_queue_loop)()
+ # set handle_queue_loop off on the background. We don't want to
+ # attribute work done in it to the current request, so we drop the
+ # logcontext altogether.
+ with PreserveLoggingContext():
+ handle_queue_loop()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -528,6 +535,12 @@ class EventsStore(SQLBaseStore):
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
+
+ logger.info(
+ "Resolving state for %s with %i state sets",
+ room_id, len(state_sets),
+ )
+
events_map = {ev.event_id: ev for ev, _ in events_context}
@defer.inlineCallbacks
@@ -550,7 +563,7 @@ class EventsStore(SQLBaseStore):
to_return.update(evs)
defer.returnValue(to_return)
- current_state = yield resolve_events(
+ current_state = yield resolve_events_with_factory(
state_sets,
state_map_factory=get_events,
)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index a66ff7c1e0..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause='url_cache IS NOT NULL',
)
- def get_default_thumbnails(self, top_level_type, sub_type):
- return []
-
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
@@ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ """Updates the last access time of the given media
+
+ Args:
+ local_media (iterable[str]): Set of media_ids
+ remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ time_ms: Current time in milliseconds
+ """
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ))
+
+ sql = (
+ "UPDATE local_media_repository SET last_access_ts = ?"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ms, media_id)
+ for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d1691bbac2..c845a0cec5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 46
+SCHEMA_VERSION = 47
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 0604f8f270..9f373b47e0 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -591,7 +591,7 @@ class RoomStore(SQLBaseStore):
"""
UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_origin AND media_id = ?
+ WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index c9bff408ef..dfdcbb3181 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- join_clause = ""
- where_clause = "?<>''" # naughty hack to keep the same number of binds
+ # make s.user_id null to keep the ordering algorithm happy
+ join_clause = """
+ CROSS JOIN (SELECT NULL as user_id) AS s
+ """
+ join_args = ()
+ where_clause = "1=1"
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
@@ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
+ join_args = (user_id,)
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine):
@@ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@@ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, search_query, limit + 1)
+ args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..90a2608d6f
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+
+import Queue
+
+
+class BackgroundFileConsumer(object):
+ """A consumer that writes to a file like object. Supports both push
+ and pull producers
+
+ Args:
+ file_obj (file): The file like object to write to. Closed when
+ finished.
+ """
+
+ # For PushProducers pause if we have this many unwritten slices
+ _PAUSE_ON_QUEUE_SIZE = 5
+ # And resume once the size of the queue is less than this
+ _RESUME_ON_QUEUE_SIZE = 2
+
+ def __init__(self, file_obj):
+ self._file_obj = file_obj
+
+ # Producer we're registered with
+ self._producer = None
+
+ # True if PushProducer, false if PullProducer
+ self.streaming = False
+
+ # For PushProducers, indicates whether we've paused the producer and
+ # need to call resumeProducing before we get more data.
+ self._paused_producer = False
+
+ # Queue of slices of bytes to be written. When producer calls
+ # unregister a final None is sent.
+ self._bytes_queue = Queue.Queue()
+
+ # Deferred that is resolved when finished writing
+ self._finished_deferred = None
+
+ # If the _writer thread throws an exception it gets stored here.
+ self._write_exception = None
+
+ def registerProducer(self, producer, streaming):
+ """Part of IConsumer interface
+
+ Args:
+ producer (IProducer)
+ streaming (bool): True if push based producer, False if pull
+ based.
+ """
+ if self._producer:
+ raise Exception("registerProducer called twice")
+
+ self._producer = producer
+ self.streaming = streaming
+ self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+ if not streaming:
+ self._producer.resumeProducing()
+
+ def unregisterProducer(self):
+ """Part of IProducer interface
+ """
+ self._producer = None
+ if not self._finished_deferred.called:
+ self._bytes_queue.put_nowait(None)
+
+ def write(self, bytes):
+ """Part of IProducer interface
+ """
+ if self._write_exception:
+ raise self._write_exception
+
+ if self._finished_deferred.called:
+ raise Exception("consumer has closed")
+
+ self._bytes_queue.put_nowait(bytes)
+
+ # If this is a PushProducer and the queue is getting behind
+ # then we pause the producer.
+ if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+ self._paused_producer = True
+ self._producer.pauseProducing()
+
+ def _writer(self):
+ """This is run in a background thread to write to the file.
+ """
+ try:
+ while self._producer or not self._bytes_queue.empty():
+ # If we've paused the producer check if we should resume the
+ # producer.
+ if self._producer and self._paused_producer:
+ if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+ reactor.callFromThread(self._resume_paused_producer)
+
+ bytes = self._bytes_queue.get()
+
+ # If we get a None (or empty list) then that's a signal used
+ # to indicate we should check if we should stop.
+ if bytes:
+ self._file_obj.write(bytes)
+
+ # If its a pull producer then we need to explicitly ask for
+ # more stuff.
+ if not self.streaming and self._producer:
+ reactor.callFromThread(self._producer.resumeProducing)
+ except Exception as e:
+ self._write_exception = e
+ raise
+ finally:
+ self._file_obj.close()
+
+ def wait(self):
+ """Returns a deferred that resolves when finished writing to file
+ """
+ return make_deferred_yieldable(self._finished_deferred)
+
+ def _resume_paused_producer(self):
+ """Gets called if we should resume producing after being paused
+ """
+ if self._paused_producer and self._producer:
+ self._paused_producer = False
+ self._producer.resumeProducing()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..94fa7cac98 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
+
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "previous_context", "name", "usage_start", "usage_end", "main_thread",
- "__dict__", "tag", "alive",
+ "previous_context", "name", "ru_stime", "ru_utime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "usage_start", "usage_end",
+ "main_thread", "alive",
+ "request", "tag",
]
thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def add_database_scheduled(self, sched_ms):
+ pass
+
def __nonzero__(self):
return False
@@ -94,9 +101,17 @@ class LoggingContext(object):
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
- self.db_txn_duration = 0.
+
+ # ms spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_ms = 0
+
+ # ms spent waiting for db txns to be scheduled
+ self.db_sched_duration_ms = 0
+
self.usage_start = None
+ self.usage_end = None
self.main_thread = threading.current_thread()
+ self.request = None
self.tag = ""
self.alive = True
@@ -105,7 +120,11 @@ class LoggingContext(object):
@classmethod
def current_context(cls):
- """Get the current logging context from thread local storage"""
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
self.alive = False
def copy_to(self, record):
- """Copy fields from this context to the record"""
- for key, value in self.__dict__.items():
- setattr(record, key, value)
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
- record.ru_utime, record.ru_stime = self.get_resource_usage()
+ # 'request' is the only field we currently use in the logger, so that's
+ # all we need to copy
+ record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
- self.db_txn_duration += duration_ms / 1000.
+ self.db_txn_duration_ms += duration_ms
+
+ def add_database_scheduled(self, sched_ms):
+ """Record a use of the database pool
+
+ Args:
+ sched_ms (int): number of milliseconds it took us to get a
+ connection
+ """
+ self.db_sched_duration_ms += sched_ms
class LoggingContextFilter(logging.Filter):
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
-block_timer = metrics.register_distribution(
- "block_timer",
- labels=["block_name"]
+# total number of times we have hit this block
+block_counter = metrics.register_counter(
+ "block_count",
+ labels=["block_name"],
+ alternative_names=(
+ # the following are all deprecated aliases for the same metric
+ metrics.name_prefix + x for x in (
+ "_block_timer:count",
+ "_block_ru_utime:count",
+ "_block_ru_stime:count",
+ "_block_db_txn_count:count",
+ "_block_db_txn_duration:count",
+ )
+ )
+)
+
+block_timer = metrics.register_counter(
+ "block_time_seconds",
+ labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_timer:total",
+ ),
)
-block_ru_utime = metrics.register_distribution(
- "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+ "block_ru_utime_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_ru_utime:total",
+ ),
)
-block_ru_stime = metrics.register_distribution(
- "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+ "block_ru_stime_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_ru_stime:total",
+ ),
)
-block_db_txn_count = metrics.register_distribution(
- "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+ "block_db_txn_count", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_db_txn_count:total",
+ ),
)
-block_db_txn_duration = metrics.register_distribution(
- "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+ "block_db_txn_duration_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_db_txn_duration:total",
+ ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+ "block_db_sched_duration_seconds", labels=["block_name"],
)
@@ -64,7 +101,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+ "ru_stime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "created_context",
]
def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+ self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
+
+ block_counter.inc(self.name)
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by(
- context.db_txn_duration - self.db_txn_duration, self.name
+ (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+ self.name
+ )
+ block_db_sched_duration.inc_by(
+ (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+ self.name
)
if self.created_context:
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1adedbb361..47b0bb5eb3 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
+ """Raised by the limiter (and federation client) to indicate that we are
+ are deliberately not attempting to contact a given server.
+
+ Args:
+ retry_last_ts (int): the unix ts in milliseconds of our last attempt
+ to contact the server. 0 indicates that the last attempt was
+ successful or that we've never actually attempted to connect.
+ retry_interval (int): the time in milliseconds to wait until the next
+ attempt.
+ destination (str): the domain in question
+ """
+
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+ """Checks whether a given format of 3PID is allowed to be used on this HS
+
+ Args:
+ hs (synapse.server.HomeServer): server
+ medium (str): 3pid medium - e.g. email, msisdn
+ address (str): address within that medium (e.g. "wotan@matrix.org")
+ msisdns need to first have been canonicalised
+ Returns:
+ bool: whether the 3PID medium/address is allowed to be added to this HS
+ """
+
+ if hs.config.allowed_local_3pids:
+ for constraint in hs.config.allowed_local_3pids:
+ logger.debug(
+ "Checking 3PID %s (%s) against %s (%s)",
+ address, medium, constraint['pattern'], constraint['medium'],
+ )
+ if (
+ medium == constraint['medium'] and
+ re.match(constraint['pattern'], address)
+ ):
+ return True
+ else:
+ return True
+
+ return False
|