summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/config/_base.py36
-rw-r--r--synapse/config/key.py8
-rw-r--r--synapse/config/registration.py7
-rw-r--r--synapse/config/tls.py6
-rw-r--r--synapse/rest/client/v2_alpha/register.py34
-rw-r--r--synapse/storage/events.py24
-rw-r--r--synapse/util/async.py5
-rw-r--r--tests/rest/client/v2_alpha/test_register.py1
-rw-r--r--tests/storage/event_injector.py76
-rw-r--r--tests/util/test_logcontext.py (renamed from tests/util/test_log_context.py)38
10 files changed, 135 insertions, 100 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1ab5593c6e..fa105bce72 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -82,21 +82,37 @@ class Config(object):
         return os.path.abspath(file_path) if file_path else file_path
 
     @classmethod
+    def path_exists(cls, file_path):
+        """Check if a file exists
+
+        Unlike os.path.exists, this throws an exception if there is an error
+        checking if the file exists (for example, if there is a perms error on
+        the parent dir).
+
+        Returns:
+            bool: True if the file exists; False if not.
+        """
+        try:
+            os.stat(file_path)
+            return True
+        except OSError as e:
+            if e.errno != errno.ENOENT:
+                raise e
+            return False
+
+    @classmethod
     def check_file(cls, file_path, config_name):
         if file_path is None:
             raise ConfigError(
                 "Missing config for %s."
-                " You must specify a path for the config file. You can "
-                "do this with the -c or --config-path option. "
-                "Adding --generate-config along with --server-name "
-                "<server name> will generate a config file at the given path."
                 % (config_name,)
             )
-        if not os.path.exists(file_path):
+        try:
+            os.stat(file_path)
+        except OSError as e:
             raise ConfigError(
-                "File %s config for %s doesn't exist."
-                " Try running again with --generate-config"
-                % (file_path, config_name,)
+                "Error accessing file '%s' (config for %s): %s"
+                % (file_path, config_name, e.strerror)
             )
         return cls.abspath(file_path)
 
@@ -248,7 +264,7 @@ class Config(object):
                     " -c CONFIG-FILE\""
                 )
             (config_path,) = config_files
-            if not os.path.exists(config_path):
+            if not cls.path_exists(config_path):
                 if config_args.keys_directory:
                     config_dir_path = config_args.keys_directory
                 else:
@@ -261,7 +277,7 @@ class Config(object):
                         "Must specify a server_name to a generate config for."
                         " Pass -H server.name."
                     )
-                if not os.path.exists(config_dir_path):
+                if not cls.path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "wb") as config_file:
                     config_bytes, config = obj.generate_config(
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 6ee643793e..4b8fc063d0 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -118,10 +118,9 @@ class KeyConfig(Config):
         signing_keys = self.read_file(signing_key_path, "signing_key")
         try:
             return read_signing_keys(signing_keys.splitlines(True))
-        except Exception:
+        except Exception as e:
             raise ConfigError(
-                "Error reading signing_key."
-                " Try running again with --generate-config"
+                "Error reading signing_key: %s" % (str(e))
             )
 
     def read_old_signing_keys(self, old_signing_keys):
@@ -141,7 +140,8 @@ class KeyConfig(Config):
 
     def generate_files(self, config):
         signing_key_path = config["signing_key_path"]
-        if not os.path.exists(signing_key_path):
+
+        if not self.path_exists(signing_key_path):
             with open(signing_key_path, "w") as signing_key_file:
                 key_id = "a_" + random_string(4)
                 write_signing_keys(
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index f7e03c4cde..ef917fc9f2 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -41,6 +41,8 @@ class RegistrationConfig(Config):
             self.allow_guest_access and config.get("invite_3pid_guest", False)
         )
 
+        self.auto_join_rooms = config.get("auto_join_rooms", [])
+
     def default_config(self, **kwargs):
         registration_shared_secret = random_string_with_symbols(50)
 
@@ -70,6 +72,11 @@ class RegistrationConfig(Config):
             - matrix.org
             - vector.im
             - riot.im
+
+        # Users who register on this homeserver will automatically be joined
+        # to these rooms
+        #auto_join_rooms:
+        #    - "#example:example.com"
         """ % locals()
 
     def add_arguments(self, parser):
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index e081840a83..247f18f454 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -126,7 +126,7 @@ class TlsConfig(Config):
         tls_private_key_path = config["tls_private_key_path"]
         tls_dh_params_path = config["tls_dh_params_path"]
 
-        if not os.path.exists(tls_private_key_path):
+        if not self.path_exists(tls_private_key_path):
             with open(tls_private_key_path, "w") as private_key_file:
                 tls_private_key = crypto.PKey()
                 tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
@@ -141,7 +141,7 @@ class TlsConfig(Config):
                     crypto.FILETYPE_PEM, private_key_pem
                 )
 
-        if not os.path.exists(tls_certificate_path):
+        if not self.path_exists(tls_certificate_path):
             with open(tls_certificate_path, "w") as certificate_file:
                 cert = crypto.X509()
                 subject = cert.get_subject()
@@ -159,7 +159,7 @@ class TlsConfig(Config):
 
                 certificate_file.write(cert_pem)
 
-        if not os.path.exists(tls_dh_params_path):
+        if not self.path_exists(tls_dh_params_path):
             if GENERATE_DH_PARAMS:
                 subprocess.check_call([
                     "openssl", "dhparam",
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 1421c18152..d9a8cdbbb5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,8 +17,10 @@
 from twisted.internet import defer
 
 import synapse
+import synapse.types
 from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
+from synapse.types import RoomID, RoomAlias
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
 from synapse.http.servlet import (
     RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
@@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
+        self.room_member_handler = hs.get_handlers().room_member_handler
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
@@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
                 generate_token=False,
             )
 
+            # auto-join the user to any rooms we're supposed to dump them into
+            fake_requester = synapse.types.create_requester(registered_user_id)
+            for r in self.hs.config.auto_join_rooms:
+                try:
+                    yield self._join_user_to_room(fake_requester, r)
+                except Exception as e:
+                    logger.error("Failed to join new user to %r: %r", r, e)
+
             # remember that we've now registered that user account, and with
             #  what user ID (since the user may not have specified)
             self.auth_handler.set_session_data(
@@ -373,6 +384,29 @@ class RegisterRestServlet(RestServlet):
         return 200, {}
 
     @defer.inlineCallbacks
+    def _join_user_to_room(self, requester, room_identifier):
+        room_id = None
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, remote_room_hosts = (
+                yield self.room_member_handler.lookup_room_alias(room_alias)
+            )
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(400, "%s was not legal room ID or room alias" % (
+                room_identifier,
+            ))
+
+        yield self.room_member_handler.update_membership(
+            requester=requester,
+            target=requester.user,
+            room_id=room_id,
+            action="join",
+        )
+
+    @defer.inlineCallbacks
     def _do_appservice_registration(self, username, as_token, body):
         user_id = yield self.registration_handler.appservice_register(
             username, as_token
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4f0b43c36d..637640ec2a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
 
 from synapse.util.async import ObservableDeferred
 from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
 )
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
@@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
 
+        NB: due to the normal usage pattern of this method, it does *not*
+        follow the synapse logcontext rules, and leaves the logcontext in
+        place whether or not the returned deferred is ready.
+
         Args:
             room_id (str):
             events_and_contexts (list[(EventBase, EventContext)]):
             backfilled (bool):
+
+        Returns:
+            defer.Deferred: a deferred which will resolve once the events are
+                persisted. Runs its callbacks *without* a logcontext.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
+            # if the last item in the queue has the same `backfilled` setting,
+            # we can just add these new events to that item.
             end_item = queue[-1]
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
@@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
     def handle_queue(self, room_id, per_item_callback):
         """Attempts to handle the queue for a room if not already being handled.
 
-        The given callback will be invoked with for each item in the queue,1
+        The given callback will be invoked with for each item in the queue,
         of type _EventPersistQueueItem. The per_item_callback will continuously
         be called with new items, unless the queue becomnes empty. The return
         value of the function will be given to the deferreds waiting on the item,
-        exceptions will be passed to the deferres as well.
+        exceptions will be passed to the deferreds as well.
 
         This function should therefore be called whenever anything is added
         to the queue.
@@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
 
         deferreds = []
         for room_id, evs_ctxs in partitioned.iteritems():
-            d = preserve_fn(self._event_persist_queue.add_to_queue)(
+            d = self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
@@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        return preserve_context_over_deferred(
+        return make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
@@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
 
         self._maybe_start_persisting(event.room_id)
 
-        yield preserve_context_over_deferred(deferred)
+        yield make_deferred_yieldable(deferred)
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -1526,7 +1536,7 @@ class EventsStore(SQLBaseStore):
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
 
-        res = yield preserve_context_over_deferred(defer.gatherResults(
+        res = yield make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self._get_event_from_row)(
                     row["internal_metadata"], row["json"], row["redacts"],
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0fd5b42523..a0a9039475 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -53,6 +53,11 @@ class ObservableDeferred(object):
 
     Cancelling or otherwise resolving an observer will not affect the original
     ObservableDeferred.
+
+    NB that it does not attempt to do anything with logcontexts; in general
+    you should probably make_deferred_yieldable the deferreds
+    returned by `observe`, and ensure that the original deferred runs its
+    callbacks in the sentinel logcontext.
     """
 
     __slots__ = ["_deferred", "_observers", "_result"]
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b6173ab2ee..821c735528 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
         self.hs.get_device_handler = Mock(return_value=self.device_handler)
         self.hs.config.enable_registration = True
+        self.hs.config.auto_join_rooms = []
 
         # init the thing we're testing
         self.servlet = RegisterRestServlet(self.hs)
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
deleted file mode 100644
index 024ac15069..0000000000
--- a/tests/storage/event_injector.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-
-
-class EventInjector:
-    def __init__(self, hs):
-        self.hs = hs
-        self.store = hs.get_datastore()
-        self.message_handler = hs.get_handlers().message_handler
-        self.event_builder_factory = hs.get_event_builder_factory()
-
-    @defer.inlineCallbacks
-    def create_room(self, room, user):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Create,
-            "sender": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-    @defer.inlineCallbacks
-    def inject_room_member(self, room, user, membership):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Member,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"membership": membership},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-        defer.returnValue(event)
-
-    @defer.inlineCallbacks
-    def inject_message(self, room, user, body):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Message,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"body": body, "msgtype": u"message"},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
diff --git a/tests/util/test_log_context.py b/tests/util/test_logcontext.py
index 9ffe209c4d..e2f7765f49 100644
--- a/tests/util/test_log_context.py
+++ b/tests/util/test_logcontext.py
@@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
                 yield defer.succeed(None)
 
         return self._test_preserve_fn(nonblocking_function)
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable(self):
+        # a function which retuns an incomplete deferred, but doesn't follow
+        # the synapse rules.
+        def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            return d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.test_key = "one"
+
+            d1 = logcontext.make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_on_non_deferred(self):
+        """Check that make_deferred_yieldable does the right thing when its
+        argument isn't actually a deferred"""
+
+        with LoggingContext() as context_one:
+            context_one.test_key = "one"
+
+            d1 = logcontext.make_deferred_yieldable("bum")
+            self._check_test_key("one")
+
+            r = yield d1
+            self.assertEqual(r, "bum")
+            self._check_test_key("one")