summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/config/registration.py7
-rw-r--r--synapse/rest/client/v2_alpha/register.py34
-rw-r--r--synapse/storage/events.py24
-rw-r--r--synapse/util/async.py5
-rw-r--r--tests/rest/client/v2_alpha/test_register.py1
-rw-r--r--tests/util/test_logcontext.py (renamed from tests/util/test_log_context.py)38
6 files changed, 102 insertions, 7 deletions
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index f7e03c4cde..ef917fc9f2 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -41,6 +41,8 @@ class RegistrationConfig(Config):
             self.allow_guest_access and config.get("invite_3pid_guest", False)
         )
 
+        self.auto_join_rooms = config.get("auto_join_rooms", [])
+
     def default_config(self, **kwargs):
         registration_shared_secret = random_string_with_symbols(50)
 
@@ -70,6 +72,11 @@ class RegistrationConfig(Config):
             - matrix.org
             - vector.im
             - riot.im
+
+        # Users who register on this homeserver will automatically be joined
+        # to these rooms
+        #auto_join_rooms:
+        #    - "#example:example.com"
         """ % locals()
 
     def add_arguments(self, parser):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 1421c18152..d9a8cdbbb5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,8 +17,10 @@
 from twisted.internet import defer
 
 import synapse
+import synapse.types
 from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
+from synapse.types import RoomID, RoomAlias
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
 from synapse.http.servlet import (
     RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
@@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
+        self.room_member_handler = hs.get_handlers().room_member_handler
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
@@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
                 generate_token=False,
             )
 
+            # auto-join the user to any rooms we're supposed to dump them into
+            fake_requester = synapse.types.create_requester(registered_user_id)
+            for r in self.hs.config.auto_join_rooms:
+                try:
+                    yield self._join_user_to_room(fake_requester, r)
+                except Exception as e:
+                    logger.error("Failed to join new user to %r: %r", r, e)
+
             # remember that we've now registered that user account, and with
             #  what user ID (since the user may not have specified)
             self.auth_handler.set_session_data(
@@ -373,6 +384,29 @@ class RegisterRestServlet(RestServlet):
         return 200, {}
 
     @defer.inlineCallbacks
+    def _join_user_to_room(self, requester, room_identifier):
+        room_id = None
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, remote_room_hosts = (
+                yield self.room_member_handler.lookup_room_alias(room_alias)
+            )
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(400, "%s was not legal room ID or room alias" % (
+                room_identifier,
+            ))
+
+        yield self.room_member_handler.update_membership(
+            requester=requester,
+            target=requester.user,
+            room_id=room_id,
+            action="join",
+        )
+
+    @defer.inlineCallbacks
     def _do_appservice_registration(self, username, as_token, body):
         user_id = yield self.registration_handler.appservice_register(
             username, as_token
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4f0b43c36d..637640ec2a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
 
 from synapse.util.async import ObservableDeferred
 from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
 )
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
@@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
 
+        NB: due to the normal usage pattern of this method, it does *not*
+        follow the synapse logcontext rules, and leaves the logcontext in
+        place whether or not the returned deferred is ready.
+
         Args:
             room_id (str):
             events_and_contexts (list[(EventBase, EventContext)]):
             backfilled (bool):
+
+        Returns:
+            defer.Deferred: a deferred which will resolve once the events are
+                persisted. Runs its callbacks *without* a logcontext.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
+            # if the last item in the queue has the same `backfilled` setting,
+            # we can just add these new events to that item.
             end_item = queue[-1]
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
@@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
     def handle_queue(self, room_id, per_item_callback):
         """Attempts to handle the queue for a room if not already being handled.
 
-        The given callback will be invoked with for each item in the queue,1
+        The given callback will be invoked with for each item in the queue,
         of type _EventPersistQueueItem. The per_item_callback will continuously
         be called with new items, unless the queue becomnes empty. The return
         value of the function will be given to the deferreds waiting on the item,
-        exceptions will be passed to the deferres as well.
+        exceptions will be passed to the deferreds as well.
 
         This function should therefore be called whenever anything is added
         to the queue.
@@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
 
         deferreds = []
         for room_id, evs_ctxs in partitioned.iteritems():
-            d = preserve_fn(self._event_persist_queue.add_to_queue)(
+            d = self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
@@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        return preserve_context_over_deferred(
+        return make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
@@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
 
         self._maybe_start_persisting(event.room_id)
 
-        yield preserve_context_over_deferred(deferred)
+        yield make_deferred_yieldable(deferred)
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -1526,7 +1536,7 @@ class EventsStore(SQLBaseStore):
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
 
-        res = yield preserve_context_over_deferred(defer.gatherResults(
+        res = yield make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self._get_event_from_row)(
                     row["internal_metadata"], row["json"], row["redacts"],
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0fd5b42523..a0a9039475 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -53,6 +53,11 @@ class ObservableDeferred(object):
 
     Cancelling or otherwise resolving an observer will not affect the original
     ObservableDeferred.
+
+    NB that it does not attempt to do anything with logcontexts; in general
+    you should probably make_deferred_yieldable the deferreds
+    returned by `observe`, and ensure that the original deferred runs its
+    callbacks in the sentinel logcontext.
     """
 
     __slots__ = ["_deferred", "_observers", "_result"]
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b6173ab2ee..821c735528 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
         self.hs.get_device_handler = Mock(return_value=self.device_handler)
         self.hs.config.enable_registration = True
+        self.hs.config.auto_join_rooms = []
 
         # init the thing we're testing
         self.servlet = RegisterRestServlet(self.hs)
diff --git a/tests/util/test_log_context.py b/tests/util/test_logcontext.py
index 9ffe209c4d..e2f7765f49 100644
--- a/tests/util/test_log_context.py
+++ b/tests/util/test_logcontext.py
@@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
                 yield defer.succeed(None)
 
         return self._test_preserve_fn(nonblocking_function)
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable(self):
+        # a function which retuns an incomplete deferred, but doesn't follow
+        # the synapse rules.
+        def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            return d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.test_key = "one"
+
+            d1 = logcontext.make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_on_non_deferred(self):
+        """Check that make_deferred_yieldable does the right thing when its
+        argument isn't actually a deferred"""
+
+        with LoggingContext() as context_one:
+            context_one.test_key = "one"
+
+            d1 = logcontext.make_deferred_yieldable("bum")
+            self._check_test_key("one")
+
+            r = yield d1
+            self.assertEqual(r, "bum")
+            self._check_test_key("one")