diff options
-rw-r--r-- | synapse/config/_base.py | 36 | ||||
-rw-r--r-- | synapse/config/key.py | 8 | ||||
-rw-r--r-- | synapse/config/registration.py | 7 | ||||
-rw-r--r-- | synapse/config/tls.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 34 | ||||
-rw-r--r-- | synapse/storage/events.py | 24 | ||||
-rw-r--r-- | synapse/util/async.py | 5 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 1 | ||||
-rw-r--r-- | tests/storage/event_injector.py | 76 | ||||
-rw-r--r-- | tests/util/test_logcontext.py (renamed from tests/util/test_log_context.py) | 38 |
10 files changed, 135 insertions, 100 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1ab5593c6e..fa105bce72 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -82,21 +82,37 @@ class Config(object): return os.path.abspath(file_path) if file_path else file_path @classmethod + def path_exists(cls, file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + @classmethod def check_file(cls, file_path, config_name): if file_path is None: raise ConfigError( "Missing config for %s." - " You must specify a path for the config file. You can " - "do this with the -c or --config-path option. " - "Adding --generate-config along with --server-name " - "<server name> will generate a config file at the given path." % (config_name,) ) - if not os.path.exists(file_path): + try: + os.stat(file_path) + except OSError as e: raise ConfigError( - "File %s config for %s doesn't exist." - " Try running again with --generate-config" - % (file_path, config_name,) + "Error accessing file '%s' (config for %s): %s" + % (file_path, config_name, e.strerror) ) return cls.abspath(file_path) @@ -248,7 +264,7 @@ class Config(object): " -c CONFIG-FILE\"" ) (config_path,) = config_files - if not os.path.exists(config_path): + if not cls.path_exists(config_path): if config_args.keys_directory: config_dir_path = config_args.keys_directory else: @@ -261,7 +277,7 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) - if not os.path.exists(config_dir_path): + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( diff --git a/synapse/config/key.py b/synapse/config/key.py index 6ee643793e..4b8fc063d0 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -118,10 +118,9 @@ class KeyConfig(Config): signing_keys = self.read_file(signing_key_path, "signing_key") try: return read_signing_keys(signing_keys.splitlines(True)) - except Exception: + except Exception as e: raise ConfigError( - "Error reading signing_key." - " Try running again with --generate-config" + "Error reading signing_key: %s" % (str(e)) ) def read_old_signing_keys(self, old_signing_keys): @@ -141,7 +140,8 @@ class KeyConfig(Config): def generate_files(self, config): signing_key_path = config["signing_key_path"] - if not os.path.exists(signing_key_path): + + if not self.path_exists(signing_key_path): with open(signing_key_path, "w") as signing_key_file: key_id = "a_" + random_string(4) write_signing_keys( diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f7e03c4cde..ef917fc9f2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -41,6 +41,8 @@ class RegistrationConfig(Config): self.allow_guest_access and config.get("invite_3pid_guest", False) ) + self.auto_join_rooms = config.get("auto_join_rooms", []) + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -70,6 +72,11 @@ class RegistrationConfig(Config): - matrix.org - vector.im - riot.im + + # Users who register on this homeserver will automatically be joined + # to these rooms + #auto_join_rooms: + # - "#example:example.com" """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e081840a83..247f18f454 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -126,7 +126,7 @@ class TlsConfig(Config): tls_private_key_path = config["tls_private_key_path"] tls_dh_params_path = config["tls_dh_params_path"] - if not os.path.exists(tls_private_key_path): + if not self.path_exists(tls_private_key_path): with open(tls_private_key_path, "w") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) @@ -141,7 +141,7 @@ class TlsConfig(Config): crypto.FILETYPE_PEM, private_key_pem ) - if not os.path.exists(tls_certificate_path): + if not self.path_exists(tls_certificate_path): with open(tls_certificate_path, "w") as certificate_file: cert = crypto.X509() subject = cert.get_subject() @@ -159,7 +159,7 @@ class TlsConfig(Config): certificate_file.write(cert_pem) - if not os.path.exists(tls_dh_params_path): + if not self.path_exists(tls_dh_params_path): if GENERATE_DH_PARAMS: subprocess.check_call([ "openssl", "dhparam", diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1421c18152..d9a8cdbbb5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -17,8 +17,10 @@ from twisted.internet import defer import synapse +import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType +from synapse.types import RoomID, RoomAlias from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string @@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.room_member_handler = hs.get_handlers().room_member_handler self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() @@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet): generate_token=False, ) + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = synapse.types.create_requester(registered_user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) self.auth_handler.set_session_data( @@ -373,6 +384,29 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield self.room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + ) + + @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( username, as_token diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4f0b43c36d..637640ec2a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -21,7 +21,7 @@ from synapse.events.utils import prune_event from synapse.util.async import ObservableDeferred from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, preserve_context_over_deferred + preserve_fn, PreserveLoggingContext, make_deferred_yieldable ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -88,13 +88,23 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + Args: room_id (str): events_and_contexts (list[(EventBase, EventContext)]): backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. end_item = queue[-1] if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) @@ -113,11 +123,11 @@ class _EventPeristenceQueue(object): def handle_queue(self, room_id, per_item_callback): """Attempts to handle the queue for a room if not already being handled. - The given callback will be invoked with for each item in the queue,1 + The given callback will be invoked with for each item in the queue, of type _EventPersistQueueItem. The per_item_callback will continuously be called with new items, unless the queue becomnes empty. The return value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferres as well. + exceptions will be passed to the deferreds as well. This function should therefore be called whenever anything is added to the queue. @@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore): deferreds = [] for room_id, evs_ctxs in partitioned.iteritems(): - d = preserve_fn(self._event_persist_queue.add_to_queue)( + d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled, ) @@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore): for room_id in partitioned: self._maybe_start_persisting(room_id) - return preserve_context_over_deferred( + return make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) @@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield preserve_context_over_deferred(deferred) + yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -1526,7 +1536,7 @@ class EventsStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield preserve_context_over_deferred(defer.gatherResults( + res = yield make_deferred_yieldable(defer.gatherResults( [ preserve_fn(self._get_event_from_row)( row["internal_metadata"], row["json"], row["redacts"], diff --git a/synapse/util/async.py b/synapse/util/async.py index 0fd5b42523..a0a9039475 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -53,6 +53,11 @@ class ObservableDeferred(object): Cancelling or otherwise resolving an observer will not affect the original ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. """ __slots__ = ["_deferred", "_observers", "_result"] diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b6173ab2ee..821c735528 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True + self.hs.config.auto_join_rooms = [] # init the thing we're testing self.servlet = RegisterRestServlet(self.hs) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py deleted file mode 100644 index 024ac15069..0000000000 --- a/tests/storage/event_injector.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.api.constants import EventTypes - - -class EventInjector: - def __init__(self, hs): - self.hs = hs - self.store = hs.get_datastore() - self.message_handler = hs.get_handlers().message_handler - self.event_builder_factory = hs.get_event_builder_factory() - - @defer.inlineCallbacks - def create_room(self, room, user): - builder = self.event_builder_factory.new({ - "type": EventTypes.Create, - "sender": user.to_string(), - "room_id": room.to_string(), - "content": {}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) diff --git a/tests/util/test_log_context.py b/tests/util/test_logcontext.py index 9ffe209c4d..e2f7765f49 100644 --- a/tests/util/test_log_context.py +++ b/tests/util/test_logcontext.py @@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase): yield defer.succeed(None) return self._test_preserve_fn(nonblocking_function) + + @defer.inlineCallbacks + def test_make_deferred_yieldable(self): + # a function which retuns an incomplete deferred, but doesn't follow + # the synapse rules. + def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + return d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.test_key = "one" + + d1 = logcontext.make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + + @defer.inlineCallbacks + def test_make_deferred_yieldable_on_non_deferred(self): + """Check that make_deferred_yieldable does the right thing when its + argument isn't actually a deferred""" + + with LoggingContext() as context_one: + context_one.test_key = "one" + + d1 = logcontext.make_deferred_yieldable("bum") + self._check_test_key("one") + + r = yield d1 + self.assertEqual(r, "bum") + self._check_test_key("one") |