From 4824a33c31c32a055fc5b8ff4d1197c0bd3933c5 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 26 Sep 2017 17:51:26 +0100 Subject: Factor out module loading to a separate place So it can be reused --- synapse/config/password_auth_providers.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83762d089a..90824cab7f 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -15,13 +15,15 @@ from ._base import Config, ConfigError -import importlib +from synapse.util.module_loader import load_module class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] + provider_config = None + # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) @@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config): if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": from ldap_auth_provider import LdapAuthProvider provider_class = LdapAuthProvider + try: + provider_config = provider_class.parse_config(provider["config"]) + except Exception as e: + raise ConfigError( + "Failed to parse config for %r: %r" % (provider['module'], e) + ) else: - # We need to import the module, and then pick the class out of - # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) - module = importlib.import_module(module) - provider_class = getattr(module, clz) + (provider_class, provider_config) = load_module(provider) - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) self.password_providers.append((provider_class, provider_config)) def default_config(self, **kwargs): -- cgit 1.5.1 From 6cd5fcd5366cfef4959d107e818d0e20d78aa483 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 26 Sep 2017 19:20:23 +0100 Subject: Make the spam checker a module --- synapse/config/homeserver.py | 4 +++- synapse/events/spamcheck.py | 37 +++++++++++++++++++---------------- synapse/federation/federation_base.py | 5 ++--- synapse/handlers/message.py | 5 +++-- synapse/server.py | 5 +++++ 5 files changed, 33 insertions(+), 23 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b22cacf8dc..3f9d9d5f8b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -34,6 +34,7 @@ from .password_auth_providers import PasswordAuthProviderConfig from .emailconfig import EmailConfig from .workers import WorkerConfig from .push import PushConfig +from .spam_checker import SpamCheckerConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -41,7 +42,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig,): + WorkerConfig, PasswordAuthProviderConfig, PushConfig, + SpamCheckerConfig,): pass diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 56fa9e556e..7b22b3413a 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -13,26 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +class SpamChecker(object): + def __init__(self, hs): + self.spam_checker = None -def check_event_for_spam(event): - """Checks if a given event is considered "spammy" by this server. + if hs.config.spam_checker is not None: + module, config = hs.config.spam_checker + print("cfg %r", config) + self.spam_checker = module(config=config) - If the server considers an event spammy, then it will be rejected if - sent by a local user. If it is sent by a user on another server, then - users receive a blank event. + def check_event_for_spam(self, event): + """Checks if a given event is considered "spammy" by this server. - Args: - event (synapse.events.EventBase): the event to be checked + If the server considers an event spammy, then it will be rejected if + sent by a local user. If it is sent by a user on another server, then + users receive a blank event. - Returns: - bool: True if the event is spammy. - """ - if not hasattr(event, "content") or "body" not in event.content: - return False + Args: + event (synapse.events.EventBase): the event to be checked - # for example: - # - # if "the third flower is green" in event.content["body"]: - # return True + Returns: + bool: True if the event is spammy. + """ + if self.spam_checker is None: + return False - return False + return self.spam_checker.check_event_for_spam(event) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index babd9ea078..a0f5d40eb3 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -16,7 +16,6 @@ import logging from synapse.api.errors import SynapseError from synapse.crypto.event_signing import check_event_content_hash -from synapse.events import spamcheck from synapse.events.utils import prune_event from synapse.util import unwrapFirstError, logcontext from twisted.internet import defer @@ -26,7 +25,7 @@ logger = logging.getLogger(__name__) class FederationBase(object): def __init__(self, hs): - pass + self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, @@ -144,7 +143,7 @@ class FederationBase(object): ) return redacted - if spamcheck.check_event_for_spam(pdu): + if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da18bf23db..37f0a2772a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.events import spamcheck from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -58,6 +57,8 @@ class MessageHandler(BaseHandler): self.action_generator = hs.get_action_generator() + self.spam_checker = hs.get_spam_checker() + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -322,7 +323,7 @@ class MessageHandler(BaseHandler): txn_id=txn_id ) - if spamcheck.check_event_for_spam(event): + if self.spam_checker.check_event_for_spam(event): raise SynapseError( 403, "Spam is not permitted here", Codes.FORBIDDEN ) diff --git a/synapse/server.py b/synapse/server.py index a38e5179e0..4d44af745e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory +from synapse.events.spamcheck import SpamChecker from synapse.federation import initialize_http_replication from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.transport.client import TransportLayerClient @@ -139,6 +140,7 @@ class HomeServer(object): 'read_marker_handler', 'action_generator', 'user_directory_handler', + 'spam_checker', ] def __init__(self, hostname, **kwargs): @@ -309,6 +311,9 @@ class HomeServer(object): def build_user_directory_handler(self): return UserDirectoyHandler(self) + def build_spam_checker(self): + return SpamChecker(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) -- cgit 1.5.1 From 1786b0e768877a608a6f44a6a37cc36e598eda4e Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 27 Sep 2017 10:22:54 +0100 Subject: Forgot the new file again :( --- synapse/config/spam_checker.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 synapse/config/spam_checker.py (limited to 'synapse/config') diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py new file mode 100644 index 0000000000..3fec42bdb0 --- /dev/null +++ b/synapse/config/spam_checker.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.module_loader import load_module + +from ._base import Config + + +class SpamCheckerConfig(Config): + def read_config(self, config): + self.spam_checker = None + + provider = config.get("spam_checker", None) + if provider is not None: + self.spam_checker = load_module(provider) + + def default_config(self, **kwargs): + return """\ + # spam_checker: + # module: "my_custom_project.SuperSpamChecker" + # config: + # example_option: 'things' + """ -- cgit 1.5.1 From bf4fb1fb400daad23702bc0b3231ec069d68e87e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Oct 2017 15:20:59 +0100 Subject: Basic implementation of backup media store --- synapse/config/repository.py | 18 +++ synapse/rest/media/v1/media_repository.py | 221 ++++++++++++++---------------- synapse/rest/media/v1/thumbnailer.py | 16 +-- synapse/rest/media/v1/upload_resource.py | 2 +- 4 files changed, 131 insertions(+), 126 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2c6f57168e..e3c83d56fa 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -70,7 +70,17 @@ class ContentRepositoryConfig(Config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_spider_size = self.parse_size(config["max_spider_size"]) + self.media_store_path = self.ensure_directory(config["media_store_path"]) + + self.backup_media_store_path = config.get("backup_media_store_path") + if self.backup_media_store_path: + self.ensure_directory(self.backup_media_store_path) + + self.synchronous_backup_media_store = config.get( + "synchronous_backup_media_store", False + ) + self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( @@ -115,6 +125,14 @@ class ContentRepositoryConfig(Config): # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" + # A secondary directory where uploaded images and attachments are + # stored as a backup. + # backup_media_store_path: "%(media_store)s" + + # Whether to wait for successful write to backup media store before + # returning successfully. + # synchronous_backup_media_store: false + # Directory where in-progress uploads are stored. uploads_path: "%(uploads_path)s" diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0ea1248ce6..3b442cc16b 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import preserve_context_over_fn, preserve_fn from synapse.util.retryutils import NotRetryingDestination import os @@ -59,7 +59,12 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels + self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.backup_filepaths = None + if hs.config.backup_media_store_path: + self.backup_filepaths = MediaFilePaths(hs.config.backup_media_store_path) + self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements @@ -87,18 +92,43 @@ class MediaRepository(object): if not os.path.exists(dirname): os.makedirs(dirname) + @defer.inlineCallbacks + def _write_to_file(self, source, file_name_func): + def write_file_thread(file_name): + source.seek(0) # Ensure we read from the start of the file + with open(file_name, "wb") as f: + shutil.copyfileobj(source, f) + + fname = file_name_func(self.filepaths) + self._makedirs(fname) + + # Write to the main repository + yield preserve_context_over_fn(threads.deferToThread, write_file_thread, fname) + + # Write to backup repository + if self.backup_filepaths: + backup_fname = file_name_func(backup_filepaths) + self._makedirs(backup_fname) + + # We can either wait for successful writing to the backup repository + # or write in the background and immediately return + if hs.config.synchronous_backup_media_store: + yield preserve_context_over_fn( + threads.deferToThread, write_file_thread, backup_fname, + ) + else: + preserve_fn(threads.deferToThread)(write_file, backup_fname) + + defer.returnValue(fname) + @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): media_id = random_string(24) - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) - - # This shouldn't block for very long because the content will have - # already been uploaded at this point. - with open(fname, "wb") as f: - f.write(content) + fname = yield self._write_to_file( + content, lambda f: f.local_media_filepath(media_id) + ) logger.info("Stored local media in file %r", fname) @@ -253,9 +283,8 @@ class MediaRepository(object): def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): - thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height @@ -267,36 +296,40 @@ class MediaRepository(object): return if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.scale(t_width, t_height, t_type) else: - t_len = None + t_byte_source = None - return t_len + return t_byte_source @defer.inlineCallbacks def generate_local_exact_thumbnail(self, media_id, t_width, t_height, t_method, t_type): input_path = self.filepaths.local_media_filepath(media_id) - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield preserve_context_over_fn( threads.deferToThread, self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type + thumbnailer, t_width, t_height, t_method, t_type ) - if t_len: + if t_byte_source: + output_path = yield self._write_to_file( + content, + lambda f: f.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + ) + logger.info("Stored thumbnail in file %r", output_path) + yield self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len + media_id, t_width, t_height, t_type, t_method, len(t_byte_source.getvalue()) ) defer.returnValue(t_path) @@ -306,21 +339,25 @@ class MediaRepository(object): t_width, t_height, t_method, t_type): input_path = self.filepaths.remote_media_filepath(server_name, file_id) - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield preserve_context_over_fn( threads.deferToThread, self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type + thumbnailer, t_width, t_height, t_method, t_type ) - if t_len: + if t_byte_source: + output_path = yield self._write_to_file( + content, + lambda f: f.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + ) + logger.info("Stored thumbnail in file %r", output_path) + yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + t_width, t_height, t_type, t_method, len(t_byte_source.getvalue()) ) defer.returnValue(t_path) @@ -351,59 +388,32 @@ class MediaRepository(object): local_thumbnails = [] def generate_thumbnails(): - scales = set() - crops = set() for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) + t_byte_source = self._generate_thumbnail( + thumbnailer, r_width, r_height, r_method, r_type, + ) - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + r_width, r_height, r_method, r_type, t_byte_source )) yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) + for t_width, t_height, t_method, t_type, t_byte_source in local_thumbnails: + if url_cache: + path_name_func = lambda f: f.url_cache_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + else: + path_name_func = lambda f: f.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + + yield self._write_to_file(t_byte_source, path_name_func) + + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, len(t_byte_source.getvalue()) + ) defer.returnValue({ "width": m_width, @@ -433,51 +443,32 @@ class MediaRepository(object): ) return - scales = set() - crops = set() for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method + t_byte_source = self._generate_thumbnail( + thumbnailer, r_width, r_height, r_method, r_type, ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) + + remote_thumbnails.append(( + r_width, r_height, r_method, r_type, t_byte_source + )) yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) for r in remote_thumbnails: yield self.store.store_remote_media_thumbnail(*r) + for t_width, t_height, t_method, t_type, t_byte_source in local_thumbnails: + path_name_func = lambda f: f.remote_media_thumbnail( + server_name, media_id, file_id, t_width, t_height, t_type, t_method + ) + + yield self._write_to_file(t_byte_source, path_name_func) + + yield self.store.store_remote_media_thumbnail( + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, len(t_byte_source.getvalue()) + ) + defer.returnValue({ "width": m_width, "height": m_height, diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3868d4f65f..60498b08aa 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -50,12 +50,12 @@ class Thumbnailer(object): else: return ((max_height * self.width) // self.height, max_height) - def scale(self, output_path, width, height, output_type): + def scale(self, width, height, output_type): """Rescales the image to the given dimensions""" scaled = self.image.resize((width, height), Image.ANTIALIAS) - return self.save_image(scaled, output_type, output_path) + return self._encode_image(scaled, output_type) - def crop(self, output_path, width, height, output_type): + def crop(self, width, height, output_type): """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -82,13 +82,9 @@ class Thumbnailer(object): crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self.save_image(cropped, output_type, output_path) + return self._encode_image(cropped, output_type) - def save_image(self, output_image, output_type, output_path): + def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) - output_bytes = output_bytes_io.getvalue() - with open(output_path, "wb") as output_file: - output_file.write(output_bytes) - logger.info("Stored thumbnail in file %r", output_path) - return len(output_bytes) + return output_bytes_io diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 4ab33f73bf..f6f498cdc5 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -93,7 +93,7 @@ class UploadResource(Resource): # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content.read(), + media_type, upload_name, request.content, content_length, requester.user ) -- cgit 1.5.1 From e283b555b1f20de4fd393fd947e82eb3c635b7e9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Oct 2017 17:31:24 +0100 Subject: Copy everything to backup --- synapse/config/repository.py | 4 +- synapse/rest/media/v1/filepath.py | 99 +++++++++++++++-------- synapse/rest/media/v1/media_repository.py | 109 ++++++++++++++++---------- synapse/rest/media/v1/preview_url_resource.py | 7 +- synapse/rest/media/v1/thumbnailer.py | 9 ++- 5 files changed, 151 insertions(+), 77 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/repository.py b/synapse/config/repository.py index e3c83d56fa..6baa474931 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -75,7 +75,9 @@ class ContentRepositoryConfig(Config): self.backup_media_store_path = config.get("backup_media_store_path") if self.backup_media_store_path: - self.ensure_directory(self.backup_media_store_path) + self.backup_media_store_path = self.ensure_directory( + self.backup_media_store_path + ) self.synchronous_backup_media_store = config.get( "synchronous_backup_media_store", False diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d5cec10127..43d0eea00d 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -15,103 +15,134 @@ import os import re +import functools NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") +def _wrap_in_base_path(func): + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + @functools.wraps(func) + def _wrapped(self, *args, **kwargs): + path = func(self, *args, **kwargs) + return os.path.join(self.primary_base_path, path) + + return _wrapped + + class MediaFilePaths(object): + """Describes where files are stored on disk. - def __init__(self, base_path): - self.base_path = base_path + Most of the function have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ - def default_thumbnail(self, default_top_level, default_sub_type, width, - height, content_type, method): + def __init__(self, primary_base_path): + self.primary_base_path = primary_base_path + + def default_thumbnail_rel(self, default_top_level, default_sub_type, width, + height, content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "default_thumbnails", default_top_level, + "default_thumbnails", default_top_level, default_sub_type, file_name ) - def local_media_filepath(self, media_id): + default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) + + def local_media_filepath_rel(self, media_id): return os.path.join( - self.base_path, "local_content", + "local_content", media_id[0:2], media_id[2:4], media_id[4:] ) - def local_media_thumbnail(self, media_id, width, height, content_type, - method): + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + def local_media_thumbnail_rel(self, media_id, width, height, content_type, + method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "local_thumbnails", + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) - def remote_media_filepath(self, server_name, file_id): + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - self.base_path, "remote_content", server_name, + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) - def remote_media_thumbnail(self, server_name, file_id, width, height, - content_type, method): + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + def remote_media_thumbnail_rel(self, server_name, file_id, width, height, + content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - self.base_path, "remote_thumbnail", server_name, + "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], file_name ) + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( - self.base_path, "remote_thumbnail", server_name, + "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], ) - def url_cache_filepath(self, media_id): + def url_cache_filepath_rel(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf return os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[:10], media_id[11:] ) else: return os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[0:2], media_id[2:4], media_id[4:], ) + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + def url_cache_filepath_dirs_to_delete(self, media_id): "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[:10], ), ] else: return [ os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[0:2], media_id[2:4], ), os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[0:2], ), ] - def url_cache_thumbnail(self, media_id, width, height, content_type, - method): + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, + method): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -122,29 +153,31 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[:10], media_id[11:], file_name ) else: return os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + def url_cache_thumbnail_directory(self, media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[:10], media_id[11:], ) else: return os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], ) @@ -155,26 +188,26 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[:10], media_id[11:], ), os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[:10], ), ] else: return [ os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], ), os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[0:2], media_id[2:4], ), os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[0:2], ), ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 93b35af9cf..398e973ca9 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -60,10 +60,12 @@ class MediaRepository(object): self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = MediaFilePaths(hs.config.media_store_path) - self.backup_filepaths = None + self.primary_base_path = hs.config.media_store_path + self.filepaths = MediaFilePaths(self.primary_base_path) + + self.backup_base_path = None if hs.config.backup_media_store_path: - self.backup_filepaths = MediaFilePaths(hs.config.backup_media_store_path) + self.backup_base_path = hs.config.backup_media_store_path self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store @@ -94,42 +96,63 @@ class MediaRepository(object): if not os.path.exists(dirname): os.makedirs(dirname) - @defer.inlineCallbacks - def _write_to_file(self, source, file_name_func): - def write_file_thread(file_name): - source.seek(0) # Ensure we read from the start of the file - with open(file_name, "wb") as f: - shutil.copyfileobj(source, f) + @staticmethod + def write_file_synchronously(source, fname): + source.seek(0) # Ensure we read from the start of the file + with open(fname, "wb") as f: + shutil.copyfileobj(source, f) - fname = file_name_func(self.filepaths) + @defer.inlineCallbacks + def write_to_file(self, source, path): + """Write `source` to the on disk media store, and also the backup store + if configured. + + Args: + source: A file like object that should be written + path: Relative path to write file to + + Returns: + string: the file path written to in the primary media store + """ + fname = os.path.join(self.primary_base_path, path) self._makedirs(fname) # Write to the main repository - yield preserve_context_over_fn(threads.deferToThread, write_file_thread, fname) + yield preserve_context_over_fn( + threads.deferToThread, + self.write_file_synchronously, source, fname, + ) # Write to backup repository - if self.backup_filepaths: - backup_fname = file_name_func(self.backup_filepaths) + yield self.copy_to_backup(source, path) + + defer.returnValue(fname) + + @defer.inlineCallbacks + def copy_to_backup(self, source, path): + if self.backup_base_path: + backup_fname = os.path.join(self.backup_base_path, path) self._makedirs(backup_fname) # We can either wait for successful writing to the backup repository # or write in the background and immediately return if self.synchronous_backup_media_store: yield preserve_context_over_fn( - threads.deferToThread, write_file_thread, backup_fname, + threads.deferToThread, + self.write_file_synchronously, source, backup_fname, ) else: - preserve_fn(threads.deferToThread)(write_file_thread, backup_fname) - - defer.returnValue(fname) + preserve_fn(threads.deferToThread)( + self.write_file_synchronously, source, backup_fname, + ) @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): media_id = random_string(24) - fname = yield self._write_to_file( - content, lambda f: f.local_media_filepath(media_id) + fname = yield self.write_to_file( + content, self.filepaths.local_media_filepath_rel(media_id) ) logger.info("Stored local media in file %r", fname) @@ -180,9 +203,10 @@ class MediaRepository(object): def _download_remote_file(self, server_name, media_id): file_id = random_string(24) - fname = self.filepaths.remote_media_filepath( + fpath = self.filepaths.remote_media_filepath_rel( server_name, file_id ) + fname = os.path.join(self.primary_base_path, fpath) self._makedirs(fname) try: @@ -224,6 +248,9 @@ class MediaRepository(object): server_name, media_id) raise SynapseError(502, "Failed to fetch remote media") + with open(fname) as f: + yield self.copy_to_backup(f, fpath) + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() @@ -322,15 +349,15 @@ class MediaRepository(object): ) if t_byte_source: - output_path = yield self._write_to_file( + output_path = yield self.write_to_file( t_byte_source, - lambda f: f.local_media_thumbnail( + self.filepaths.local_media_thumbnail_rel( media_id, t_width, t_height, t_type, t_method ) ) logger.info("Stored thumbnail in file %r", output_path) - yield self.store.store_local_thumbnail( + yield self.store.store_local_thumbnail_rel( media_id, t_width, t_height, t_type, t_method, len(t_byte_source.getvalue()) ) @@ -350,15 +377,15 @@ class MediaRepository(object): ) if t_byte_source: - output_path = yield self._write_to_file( + output_path = yield self.write_to_file( t_byte_source, - lambda f: f.remote_media_thumbnail( + self.filepaths.remote_media_thumbnail_rel( server_name, file_id, t_width, t_height, t_type, t_method ) ) logger.info("Stored thumbnail in file %r", output_path) - yield self.store.store_remote_media_thumbnail( + yield self.store.store_remote_media_thumbnail_rel( server_name, media_id, file_id, t_width, t_height, t_type, t_method, len(t_byte_source.getvalue()) ) @@ -403,17 +430,16 @@ class MediaRepository(object): yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) for t_width, t_height, t_method, t_type, t_byte_source in local_thumbnails: - def path_name_func(f): - if url_cache: - return f.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - return f.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) + if url_cache: + file_path = self.filepaths.url_cache_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) + else: + file_path = self.filepaths.local_media_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) - yield self._write_to_file(t_byte_source, path_name_func) + yield self.write_to_file(t_byte_source, file_path) yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, @@ -460,12 +486,11 @@ class MediaRepository(object): yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) for t_width, t_height, t_method, t_type, t_byte_source in remote_thumbnails: - def path_name_func(f): - return f.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) + file_path = self.filepaths.remote_media_thumbnail_rel( + server_name, file_id, t_width, t_height, t_type, t_method + ) - yield self._write_to_file(t_byte_source, path_name_func) + yield self.write_to_file(t_byte_source, file_path) yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -491,6 +516,8 @@ class MediaRepository(object): logger.info("Deleting: %r", key) + # TODO: Should we delete from the backup store + with (yield self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 895b480d5c..f82b8fbc51 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -59,6 +59,7 @@ class PreviewUrlResource(Resource): self.store = hs.get_datastore() self.client = SpiderHttpClient(hs) self.media_repo = media_repo + self.primary_base_path = media_repo.primary_base_path self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist @@ -262,7 +263,8 @@ class PreviewUrlResource(Resource): file_id = datetime.date.today().isoformat() + '_' + random_string(16) - fname = self.filepaths.url_cache_filepath(file_id) + fpath = self.filepaths.url_cache_filepath_rel(file_id) + fname = os.path.join(self.primary_base_path, fpath) self.media_repo._makedirs(fname) try: @@ -273,6 +275,9 @@ class PreviewUrlResource(Resource): ) # FIXME: pass through 404s and other error messages nicely + with open(fname) as f: + yield self.media_repo.copy_to_backup(f, fpath) + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 60498b08aa..e1ee535b9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -51,7 +51,11 @@ class Thumbnailer(object): return ((max_height * self.width) // self.height, max_height) def scale(self, width, height, output_type): - """Rescales the image to the given dimensions""" + """Rescales the image to the given dimensions. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk + """ scaled = self.image.resize((width, height), Image.ANTIALIAS) return self._encode_image(scaled, output_type) @@ -65,6 +69,9 @@ class Thumbnailer(object): Args: max_width: The largest possible width. max_height: The larget possible height. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width -- cgit 1.5.1 From c05e6015cc522b838ab2db6bffd494b017cf8ec6 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 16 Oct 2017 17:57:27 +0100 Subject: Add config option to auto-join new users to rooms New users who register on the server will be dumped into all rooms in auto_join_rooms in the config. --- synapse/config/registration.py | 6 ++++++ synapse/rest/client/v2_alpha/register.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) (limited to 'synapse/config') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f7e03c4cde..9e2a6d1ae5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -41,6 +41,8 @@ class RegistrationConfig(Config): self.allow_guest_access and config.get("invite_3pid_guest", False) ) + self.auto_join_rooms = config.get("auto_join_rooms", []) + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -70,6 +72,10 @@ class RegistrationConfig(Config): - matrix.org - vector.im - riot.im + + # Users who register on this homeserver will automatically be joined to these rooms + #auto_join_rooms: + # - "#example:example.com" """ % locals() def add_arguments(self, parser): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1421c18152..d9a8cdbbb5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -17,8 +17,10 @@ from twisted.internet import defer import synapse +import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType +from synapse.types import RoomID, RoomAlias from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string @@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.room_member_handler = hs.get_handlers().room_member_handler self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() @@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet): generate_token=False, ) + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = synapse.types.create_requester(registered_user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) self.auth_handler.set_session_data( @@ -372,6 +383,29 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} + @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield self.room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + ) + @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( -- cgit 1.5.1 From a9c2e930ac455db13ce37a90b3ff0d93c0d1ea43 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 17 Oct 2017 10:13:13 +0100 Subject: pep8 --- synapse/config/registration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/config') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9e2a6d1ae5..ef917fc9f2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -73,7 +73,8 @@ class RegistrationConfig(Config): - vector.im - riot.im - # Users who register on this homeserver will automatically be joined to these rooms + # Users who register on this homeserver will automatically be joined + # to these rooms #auto_join_rooms: # - "#example:example.com" """ % locals() -- cgit 1.5.1 From 7216c76654bbb57bc0ebc27498de7eda247f8ffb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 Oct 2017 14:46:17 +0100 Subject: Improve error handling for missing files (#2551) `os.path.exists` doesn't allow us to distinguish between permissions errors and the path actually not existing, which repeatedly confuses people. It also means that we try to overwrite existing key files, which is super-confusing. (cf issues #2455, #2379). Use os.stat instead. Also, don't recomemnd the the use of --generate-config, which screws everything up if you're using debian (cf #2455). --- synapse/config/_base.py | 36 ++++++++++++++++++++++++++---------- synapse/config/key.py | 8 ++++---- synapse/config/tls.py | 6 +++--- 3 files changed, 33 insertions(+), 17 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1ab5593c6e..fa105bce72 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -81,22 +81,38 @@ class Config(object): def abspath(file_path): return os.path.abspath(file_path) if file_path else file_path + @classmethod + def path_exists(cls, file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + @classmethod def check_file(cls, file_path, config_name): if file_path is None: raise ConfigError( "Missing config for %s." - " You must specify a path for the config file. You can " - "do this with the -c or --config-path option. " - "Adding --generate-config along with --server-name " - " will generate a config file at the given path." % (config_name,) ) - if not os.path.exists(file_path): + try: + os.stat(file_path) + except OSError as e: raise ConfigError( - "File %s config for %s doesn't exist." - " Try running again with --generate-config" - % (file_path, config_name,) + "Error accessing file '%s' (config for %s): %s" + % (file_path, config_name, e.strerror) ) return cls.abspath(file_path) @@ -248,7 +264,7 @@ class Config(object): " -c CONFIG-FILE\"" ) (config_path,) = config_files - if not os.path.exists(config_path): + if not cls.path_exists(config_path): if config_args.keys_directory: config_dir_path = config_args.keys_directory else: @@ -261,7 +277,7 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) - if not os.path.exists(config_dir_path): + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( diff --git a/synapse/config/key.py b/synapse/config/key.py index 6ee643793e..4b8fc063d0 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -118,10 +118,9 @@ class KeyConfig(Config): signing_keys = self.read_file(signing_key_path, "signing_key") try: return read_signing_keys(signing_keys.splitlines(True)) - except Exception: + except Exception as e: raise ConfigError( - "Error reading signing_key." - " Try running again with --generate-config" + "Error reading signing_key: %s" % (str(e)) ) def read_old_signing_keys(self, old_signing_keys): @@ -141,7 +140,8 @@ class KeyConfig(Config): def generate_files(self, config): signing_key_path = config["signing_key_path"] - if not os.path.exists(signing_key_path): + + if not self.path_exists(signing_key_path): with open(signing_key_path, "w") as signing_key_file: key_id = "a_" + random_string(4) write_signing_keys( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e081840a83..247f18f454 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -126,7 +126,7 @@ class TlsConfig(Config): tls_private_key_path = config["tls_private_key_path"] tls_dh_params_path = config["tls_dh_params_path"] - if not os.path.exists(tls_private_key_path): + if not self.path_exists(tls_private_key_path): with open(tls_private_key_path, "w") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) @@ -141,7 +141,7 @@ class TlsConfig(Config): crypto.FILETYPE_PEM, private_key_pem ) - if not os.path.exists(tls_certificate_path): + if not self.path_exists(tls_certificate_path): with open(tls_certificate_path, "w") as certificate_file: cert = crypto.X509() subject = cert.get_subject() @@ -159,7 +159,7 @@ class TlsConfig(Config): certificate_file.write(cert_pem) - if not os.path.exists(tls_dh_params_path): + if not self.path_exists(tls_dh_params_path): if GENERATE_DH_PARAMS: subprocess.check_call([ "openssl", "dhparam", -- cgit 1.5.1 From 29bafe2f7e82e48b9aad03fb23a790b3719faf78 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Oct 2017 12:13:44 +0100 Subject: Add config to enable group creation --- synapse/config/homeserver.py | 3 ++- synapse/groups/groups_server.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 3f9d9d5f8b..05e242aef6 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -35,6 +35,7 @@ from .emailconfig import EmailConfig from .workers import WorkerConfig from .push import PushConfig from .spam_checker import SpamCheckerConfig +from .groups import GroupsConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -43,7 +44,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig,): + SpamCheckerConfig, GroupsConfig,): pass diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index e9b44c0971..c19d733d76 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -704,10 +704,18 @@ class GroupsServerHandler(object): if group: raise SynapseError(400, "Group already exists") - # TODO: Add config to enforce that only server admins can create rooms is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) if not is_admin: - raise SynapseError(403, "Only server admin can create group on this server") + if not self.hs.config.enable_group_creation: + raise SynapseError(403, "Only server admin can create group on this server") + localpart = GroupID.from_string(group_id).localpart + if not localpart.startswith(self.hs.config.group_creation_prefix): + raise SynapseError( + 400, + "Can only create groups with prefix %r on this server" % ( + self.hs.config.group_creation_prefix, + ), + ) profile = content.get("profile", {}) name = profile.get("name") -- cgit 1.5.1 From ffd3f1a7838eb12a30abe40831275f360b528d1f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Oct 2017 12:17:30 +0100 Subject: Add missing file... --- synapse/config/groups.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 synapse/config/groups.py (limited to 'synapse/config') diff --git a/synapse/config/groups.py b/synapse/config/groups.py new file mode 100644 index 0000000000..7683a37534 --- /dev/null +++ b/synapse/config/groups.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.module_loader import load_module + +from ._base import Config + +from distutils.util import strtobool + + +class GroupsConfig(Config): + def read_config(self, config): + self.enable_group_creation = config.get("enable_group_creation", False) + self.group_creation_prefix = config.get("group_creation_prefix", "") + + def default_config(self, **kwargs): + return """\ + # Whether to allow non server admins to create groups on this server + enable_group_creation: false + + # If enabled, non server admins can only create groups with local parts + # starting with this prefix + # group_creation_prefix: "unofficial/" + """ -- cgit 1.5.1 From c7d46510d7a700fda9730d90b010e1e1e596c58e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Oct 2017 13:36:06 +0100 Subject: Flake8 --- synapse/config/groups.py | 4 ---- synapse/groups/groups_server.py | 6 ++++-- 2 files changed, 4 insertions(+), 6 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/groups.py b/synapse/config/groups.py index 7683a37534..997fa2881f 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -13,12 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.module_loader import load_module - from ._base import Config -from distutils.util import strtobool - class GroupsConfig(Config): def read_config(self, config): diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index c19d733d76..fc4edb7f04 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -707,7 +707,9 @@ class GroupsServerHandler(object): is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) if not is_admin: if not self.hs.config.enable_group_creation: - raise SynapseError(403, "Only server admin can create group on this server") + raise SynapseError( + 403, "Only server admin can create group on this server", + ) localpart = GroupID.from_string(group_id).localpart if not localpart.startswith(self.hs.config.group_creation_prefix): raise SynapseError( @@ -715,7 +717,7 @@ class GroupsServerHandler(object): "Can only create groups with prefix %r on this server" % ( self.hs.config.group_creation_prefix, ), - ) + ) profile = content.get("profile", {}) name = profile.get("name") -- cgit 1.5.1