summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py9
-rw-r--r--tests/rest/client/test_transactions.py6
-rw-r--r--tests/rest/media/v1/test_media_storage.py5
-rw-r--r--tests/test_distributor.py2
-rw-r--r--tests/test_event_auth.py151
-rw-r--r--tests/test_state.py16
-rw-r--r--tests/util/caches/test_descriptors.py2
-rw-r--r--tests/util/test_file_consumer.py6
-rw-r--r--tests/util/test_linearizer.py7
-rw-r--r--tests/util/test_logcontext.py11
-rw-r--r--tests/util/test_stream_change_cache.py198
-rw-r--r--tests/utils.py7
12 files changed, 393 insertions, 27 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 149e443022..cc1c862ba4 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,10 +19,10 @@ import signedjson.sign
 from mock import Mock
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
-from synapse.util import async, logcontext
+from synapse.util import logcontext, Clock
 from synapse.util.logcontext import LoggingContext
 from tests import unittest, utils
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 
 
 class MockPerspectiveServer(object):
@@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
+        clock = Clock(reactor)
         key1 = signedjson.key.generate_signing_key(1)
 
         kr = keyring.Keyring(self.hs)
@@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
 
             # wait a tick for it to send the request to the perspectives server
             # (it first tries the datastore)
-            yield async.sleep(1)   # XXX find out why this takes so long!
+            yield clock.sleep(1)   # XXX find out why this takes so long!
             self.http_client.post_json.assert_called_once()
 
             self.assertIs(LoggingContext.current_context(), context_11)
@@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1)],
                 )
-                yield async.sleep(1)
+                yield clock.sleep(1)
                 self.http_client.post_json.assert_not_called()
                 res_deferreds_2[0].addBoth(self.check_context, None)
 
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index b5bc2fa255..6a757289db 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,9 +1,9 @@
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from mock import Mock, call
 
-from synapse.util import async
+from synapse.util import Clock
 from synapse.util.logcontext import LoggingContext
 from tests import unittest
 from tests.utils import MockClock
@@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_logcontexts_with_async_result(self):
         @defer.inlineCallbacks
         def cb():
-            yield async.sleep(0)
+            yield Clock(reactor).sleep(0)
             defer.returnValue("yay")
 
         @defer.inlineCallbacks
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index eef38b6781..c5e2f5549a 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
         self.secondary_base_path = os.path.join(self.test_dir, "secondary")
 
         hs = Mock()
+        hs.get_reactor = Mock(return_value=reactor)
         hs.config.media_store_path = self.primary_base_path
 
         storage_providers = [FileStorageProviderBackend(
@@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
 
         self.filepaths = MediaFilePaths(self.primary_base_path)
         self.media_storage = MediaStorage(
-            self.primary_base_path, self.filepaths, storage_providers,
+            hs, self.primary_base_path, self.filepaths, storage_providers,
         )
 
     def tearDown(self):
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 010aeaee7e..c066381698 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
 from mock import Mock, patch
 
 from synapse.util.distributor import Distributor
-from synapse.util.async import run_on_reactor
 
 
 class DistributorTestCase(unittest.TestCase):
@@ -95,7 +94,6 @@ class DistributorTestCase(unittest.TestCase):
 
         @defer.inlineCallbacks
         def observer():
-            yield run_on_reactor()
             raise MyException("Oopsie")
 
         self.dist.observe("whail", observer)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
new file mode 100644
index 0000000000..d08e19c53a
--- /dev/null
+++ b/tests/test_event_auth.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse import event_auth
+from synapse.api.errors import AuthError
+from synapse.events import FrozenEvent
+import unittest
+
+
+class EventAuthTestCase(unittest.TestCase):
+    def test_random_users_cannot_send_state_before_first_pl(self):
+        """
+        Check that, before the first PL lands, the creator is the only user
+        that can send a state event.
+        """
+        creator = "@creator:example.com"
+        joiner = "@joiner:example.com"
+        auth_events = {
+            ("m.room.create", ""): _create_event(creator),
+            ("m.room.member", creator): _join_event(creator),
+            ("m.room.member", joiner): _join_event(joiner),
+        }
+
+        # creator should be able to send state
+        event_auth.check(
+            _random_state_event(creator), auth_events,
+            do_sig_check=False,
+        )
+
+        # joiner should not be able to send state
+        self.assertRaises(
+            AuthError,
+            event_auth.check,
+            _random_state_event(joiner),
+            auth_events,
+            do_sig_check=False,
+        ),
+
+    def test_state_default_level(self):
+        """
+        Check that users above the state_default level can send state and
+        those below cannot
+        """
+        creator = "@creator:example.com"
+        pleb = "@joiner:example.com"
+        king = "@joiner2:example.com"
+
+        auth_events = {
+            ("m.room.create", ""): _create_event(creator),
+            ("m.room.member", creator): _join_event(creator),
+            ("m.room.power_levels", ""): _power_levels_event(creator, {
+                "state_default": "30",
+                "users": {
+                    pleb: "29",
+                    king: "30",
+                },
+            }),
+            ("m.room.member", pleb): _join_event(pleb),
+            ("m.room.member", king): _join_event(king),
+        }
+
+        # pleb should not be able to send state
+        self.assertRaises(
+            AuthError,
+            event_auth.check,
+            _random_state_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        ),
+
+        # king should be able to send state
+        event_auth.check(
+            _random_state_event(king), auth_events,
+            do_sig_check=False,
+        )
+
+
+# helpers for making events
+
+TEST_ROOM_ID = "!test:room"
+
+
+def _create_event(user_id):
+    return FrozenEvent({
+        "room_id": TEST_ROOM_ID,
+        "event_id": _get_event_id(),
+        "type": "m.room.create",
+        "sender": user_id,
+        "content": {
+            "creator": user_id,
+        },
+    })
+
+
+def _join_event(user_id):
+    return FrozenEvent({
+        "room_id": TEST_ROOM_ID,
+        "event_id": _get_event_id(),
+        "type": "m.room.member",
+        "sender": user_id,
+        "state_key": user_id,
+        "content": {
+            "membership": "join",
+        },
+    })
+
+
+def _power_levels_event(sender, content):
+    return FrozenEvent({
+        "room_id": TEST_ROOM_ID,
+        "event_id": _get_event_id(),
+        "type": "m.room.power_levels",
+        "sender": sender,
+        "state_key": "",
+        "content": content,
+    })
+
+
+def _random_state_event(sender):
+    return FrozenEvent({
+        "room_id": TEST_ROOM_ID,
+        "event_id": _get_event_id(),
+        "type": "test.state",
+        "sender": sender,
+        "state_key": "",
+        "content": {
+            "membership": "join",
+        },
+    })
+
+
+event_count = 0
+
+
+def _get_event_id():
+    global event_count
+    c = event_count
+    event_count += 1
+    return "!%i:example.com" % (c, )
diff --git a/tests/test_state.py b/tests/test_state.py
index a5c5e55951..71c412faf4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
+        power_levels = create_event(
+            type=EventTypes.PowerLevels, state_key="",
+            content={"users": {
+                "@foo:bar": "100",
+                "@user_id:example.com": "100",
+            }}
+        )
+
         creation = create_event(
             type=EventTypes.Create, state_key="",
             content={"creator": "@foo:bar"}
@@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase):
 
         old_state_1 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         old_state_2 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
@@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         self.assertEqual(
-            old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+            old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
         )
 
         # Reverse the depth to make sure we are actually using the depths
@@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase):
 
         old_state_1 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         old_state_2 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
@@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         self.assertEqual(
-            old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+            old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
         )
 
     def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..24754591df 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -18,7 +18,6 @@ import logging
 
 import mock
 from synapse.api.errors import SynapseError
-from synapse.util import async
 from synapse.util import logcontext
 from twisted.internet import defer
 from synapse.util.caches import descriptors
@@ -195,7 +194,6 @@ class DescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 @defer.inlineCallbacks
                 def inner_fn():
-                    yield async.run_on_reactor()
                     raise SynapseError(400, "blah")
 
                 return inner_fn()
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index d6e1082779..c2aae8f54c 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
     def test_pull_consumer(self):
         string_file = StringIO()
-        consumer = BackgroundFileConsumer(string_file)
+        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
             producer = DummyPullProducer()
@@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
     def test_push_consumer(self):
         string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file)
+        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=[])
@@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
     def test_push_producer_feedback(self):
         string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file)
+        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..bf7e3aa885 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util import async, logcontext
+
+from synapse.util import logcontext, Clock
 from tests import unittest
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 
 from synapse.util.async import Linearizer
 from six.moves import range
@@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
                     self.assertEqual(
                         logcontext.LoggingContext.current_context(), lc)
                     if sleep:
-                        yield async.sleep(0)
+                        yield Clock(reactor).sleep(0)
 
                 self.assertEqual(
                     logcontext.LoggingContext.current_context(), lc)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index ad78d884e0..9cf90fcfc4 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -3,8 +3,7 @@ from twisted.internet import defer
 from twisted.internet import reactor
 from .. import unittest
 
-from synapse.util.async import sleep
-from synapse.util import logcontext
+from synapse.util import logcontext, Clock
 from synapse.util.logcontext import LoggingContext
 
 
@@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_sleep(self):
+        clock = Clock(reactor)
+
         @defer.inlineCallbacks
         def competing_callback():
             with LoggingContext() as competing_context:
                 competing_context.request = "competing"
-                yield sleep(0)
+                yield clock.sleep(0)
                 self._check_test_key("competing")
 
         reactor.callLater(0, competing_callback)
 
         with LoggingContext() as context_one:
             context_one.request = "one"
-            yield sleep(0)
+            yield clock.sleep(0)
             self._check_test_key("one")
 
     def _test_run_in_background(self, function):
@@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
     def test_run_in_background_with_blocking_fn(self):
         @defer.inlineCallbacks
         def blocking_function():
-            yield sleep(0)
+            yield Clock(reactor).sleep(0)
 
         return self._test_run_in_background(blocking_function)
 
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
new file mode 100644
index 0000000000..67ece166c7
--- /dev/null
+++ b/tests/util/test_stream_change_cache.py
@@ -0,0 +1,198 @@
+from tests import unittest
+from mock import patch
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class StreamChangeCacheTests(unittest.TestCase):
+    """
+    Tests for StreamChangeCache.
+    """
+
+    def test_prefilled_cache(self):
+        """
+        Providing a prefilled cache to StreamChangeCache will result in a cache
+        with the prefilled-cache entered in.
+        """
+        cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
+
+    def test_has_entity_changed(self):
+        """
+        StreamChangeCache.entity_has_changed will mark entities as changed, and
+        has_entity_changed will observe the changed entities.
+        """
+        cache = StreamChangeCache("#test", 3)
+
+        cache.entity_has_changed("user@foo.com", 6)
+        cache.entity_has_changed("bar@baz.net", 7)
+
+        # If it's been changed after that stream position, return True
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+        self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+
+        # If it's been changed at that stream position, return False
+        self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+
+        # If there's no changes after that stream position, return False
+        self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+
+        # If the entity does not exist, return False.
+        self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+
+        # If we request before the stream cache's earliest known position,
+        # return True, whether it's a known entity or not.
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
+        self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+
+    @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
+    def test_has_entity_changed_pops_off_start(self):
+        """
+        StreamChangeCache.entity_has_changed will respect the max size and
+        purge the oldest items upon reaching that max size.
+        """
+        cache = StreamChangeCache("#test", 1, max_size=2)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # The cache is at the max size, 2
+        self.assertEqual(len(cache._cache), 2)
+
+        # The oldest item has been popped off
+        self.assertTrue("user@foo.com" not in cache._entity_to_key)
+
+        # If we update an existing entity, it keeps the two existing entities
+        cache.entity_has_changed("bar@baz.net", 5)
+        self.assertEqual(
+            set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+        )
+
+    def test_get_all_entities_changed(self):
+        """
+        StreamChangeCache.get_all_entities_changed will return all changed
+        entities since the given position.  If the position is before the start
+        of the known stream, it returns None instead.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        self.assertEqual(
+            cache.get_all_entities_changed(1),
+            ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+        )
+        self.assertEqual(
+            cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
+        )
+        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
+        self.assertEqual(cache.get_all_entities_changed(0), None)
+
+    def test_has_any_entity_changed(self):
+        """
+        StreamChangeCache.has_any_entity_changed will return True if any
+        entities have been changed since the provided stream position, and
+        False if they have not.  If the cache has entries and the provided
+        stream position is before it, it will return True, otherwise False if
+        the cache has no entries.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        # With no entities, it returns False for the past, present, and future.
+        self.assertFalse(cache.has_any_entity_changed(0))
+        self.assertFalse(cache.has_any_entity_changed(1))
+        self.assertFalse(cache.has_any_entity_changed(2))
+
+        # We add an entity
+        cache.entity_has_changed("user@foo.com", 2)
+
+        # With an entity, it returns True for the past, the stream start
+        # position, and False for the stream position the entity was changed
+        # on and ones after it.
+        self.assertTrue(cache.has_any_entity_changed(0))
+        self.assertTrue(cache.has_any_entity_changed(1))
+        self.assertFalse(cache.has_any_entity_changed(2))
+        self.assertFalse(cache.has_any_entity_changed(3))
+
+    def test_get_entities_changed(self):
+        """
+        StreamChangeCache.get_entities_changed will return the entities in the
+        given list that have changed since the provided stream ID.  If the
+        stream position is earlier than the earliest known position, it will
+        return all of the entities queried for.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # Query all the entries, but mid-way through the stream. We should only
+        # get the ones after that point.
+        self.assertEqual(
+            cache.get_entities_changed(
+                ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+            ),
+            set(["bar@baz.net", "user@elsewhere.org"]),
+        )
+
+        # Query all the entries mid-way through the stream, but include one
+        # that doesn't exist in it. We should get back the one that doesn't
+        # exist, too.
+        self.assertEqual(
+            cache.get_entities_changed(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ],
+                stream_pos=2,
+            ),
+            set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+        )
+
+        # Query all the entries, but before the first known point. We will get
+        # all the entries we queried for, including ones that don't exist.
+        self.assertEqual(
+            cache.get_entities_changed(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ],
+                stream_pos=0,
+            ),
+            set(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ]
+            ),
+        )
+
+    def test_max_pos(self):
+        """
+        StreamChangeCache.get_max_pos_of_last_change will return the most
+        recent point where the entity could have changed.  If the entity is not
+        known, the stream start is provided instead.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # Known entities will return the point where they were changed.
+        self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2)
+        self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
+        self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
+
+        # Unknown entities will return the stream start position.
+        self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
diff --git a/tests/utils.py b/tests/utils.py
index 262c4a5714..189fd2711c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
 
 
 @defer.inlineCallbacks
-def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
+def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
+                          **kargs):
     """Setup a homeserver suitable for running tests against. Keyword arguments
     are passed to the Homeserver constructor. If no datastore is supplied a
     datastore backed by an in-memory sqlite db will be given to the HS.
     """
+    if reactor is None:
+        from twisted.internet import reactor
+
     if config is None:
         config = Mock()
         config.signing_key = [MockKey()]
@@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
             database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
+            reactor=reactor,
             **kargs
         )
         db_conn = hs.get_db_conn()