summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_ratelimiting.py20
-rw-r--r--tests/handlers/test_profile.py4
-rw-r--r--tests/handlers/test_typing.py290
-rw-r--r--tests/replication/slave/storage/_base.py4
-rw-r--r--tests/rest/client/v1/test_admin.py38
-rw-r--r--tests/rest/client/v1/test_events.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py48
-rw-r--r--tests/server.py21
-rw-r--r--tests/test_mau.py3
-rw-r--r--tests/unittest.py8
-rw-r--r--tests/utils.py37
13 files changed, 283 insertions, 204 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 8933fe3b72..30a255d441 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -6,34 +6,34 @@ from tests import unittest
 class TestRatelimiter(unittest.TestCase):
     def test_allowed(self):
         limiter = Ratelimiter()
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
         )
         self.assertTrue(allowed)
         self.assertEquals(10., time_allowed)
 
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
         )
         self.assertFalse(allowed)
         self.assertEquals(10., time_allowed)
 
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
         )
         self.assertTrue(allowed)
         self.assertEquals(20., time_allowed)
 
     def test_pruning(self):
         limiter = Ratelimiter()
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
         )
 
         self.assertIn("test_id_1", limiter.message_counts)
 
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
         )
 
         self.assertNotIn("test_id_1", limiter.message_counts)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 80da1c8954..d60c124eec 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase):
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         self.store = hs.get_datastore()
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 36e136cded..13486930fb 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,13 +24,17 @@ from synapse.api.errors import AuthError
 from synapse.types import UserID
 
 from tests import unittest
+from tests.utils import register_federation_servlets
 
-from ..utils import (
-    DeferredMockCallable,
-    MockClock,
-    MockHttpResource,
-    setup_test_homeserver,
-)
+# Some local users to test with
+U_APPLE = UserID.from_string("@apple:test")
+U_BANANA = UserID.from_string("@banana:test")
+
+# Remote user
+U_ONION = UserID.from_string("@onion:farm")
+
+# Test room id
+ROOM_ID = "a-room"
 
 
 def _expect_edu_transaction(edu_type, content, origin="test"):
@@ -46,30 +50,21 @@ def _make_edu_transaction_json(edu_type, content):
     return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
 
 
-class TypingNotificationsTestCase(unittest.TestCase):
-    """Tests typing notifications to rooms."""
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.clock = MockClock()
+class TypingNotificationsTestCase(unittest.HomeserverTestCase):
+    servlets = [register_federation_servlets]
 
-        self.mock_http_client = Mock(spec=[])
-        self.mock_http_client.put_json = DeferredMockCallable()
+    def make_homeserver(self, reactor, clock):
+        # we mock out the keyring so as to skip the authentication check on the
+        # federation API call.
+        mock_keyring = Mock(spec=["verify_json_for_server"])
+        mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
 
-        self.mock_federation_resource = MockHttpResource()
-
-        mock_notifier = Mock()
-        self.on_new_event = mock_notifier.on_new_event
+        # we mock out the federation client too
+        mock_federation_client = Mock(spec=["put_json"])
+        mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
 
-        self.auth = Mock(spec=[])
-        self.state_handler = Mock()
-
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
-            "test",
-            auth=self.auth,
-            clock=self.clock,
-            datastore=Mock(
+        hs = self.setup_test_homeserver(
+            datastore=(Mock(
                 spec=[
                     # Bits that Federation needs
                     "prep_send_transaction",
@@ -82,16 +77,21 @@ class TypingNotificationsTestCase(unittest.TestCase):
                     "get_user_directory_stream_pos",
                     "get_current_state_deltas",
                 ]
-            ),
-            state_handler=self.state_handler,
-            handlers=Mock(),
-            notifier=mock_notifier,
-            resource_for_client=Mock(),
-            resource_for_federation=self.mock_federation_resource,
-            http_client=self.mock_http_client,
-            keyring=Mock(),
+            )),
+            notifier=Mock(),
+            http_client=mock_federation_client,
+            keyring=mock_keyring,
         )
 
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        # the tests assume that we are starting at unix time 1000
+        reactor.pump((1000, ))
+
+        mock_notifier = hs.get_notifier()
+        self.on_new_event = mock_notifier.on_new_event
+
         self.handler = hs.get_typing_handler()
 
         self.event_source = hs.get_event_sources().sources["typing"]
@@ -109,13 +109,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.datastore.get_received_txn_response = get_received_txn_response
 
-        self.room_id = "a-room"
-
         self.room_members = []
 
         def check_joined_room(room_id, user_id):
             if user_id not in [u.to_string() for u in self.room_members]:
                 raise AuthError(401, "User is not in the room")
+        hs.get_auth().check_joined_room = check_joined_room
 
         def get_joined_hosts_for_room(room_id):
             return set(member.domain for member in self.room_members)
@@ -124,8 +123,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         def get_current_user_in_room(room_id):
             return set(str(u) for u in self.room_members)
-
-        self.state_handler.get_current_user_in_room = get_current_user_in_room
+        hs.get_state_handler().get_current_user_in_room = get_current_user_in_room
 
         self.datastore.get_user_directory_stream_pos.return_value = (
             # we deliberately return a non-None stream pos to avoid doing an initial_spam
@@ -134,230 +132,208 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.datastore.get_current_state_deltas.return_value = None
 
-        self.auth.check_joined_room = check_joined_room
-
         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
 
-        # Some local users to test with
-        self.u_apple = UserID.from_string("@apple:test")
-        self.u_banana = UserID.from_string("@banana:test")
-
-        # Remote user
-        self.u_onion = UserID.from_string("@onion:farm")
-
-    @defer.inlineCallbacks
     def test_started_typing_local(self):
-        self.room_members = [self.u_apple, self.u_banana]
+        self.room_members = [U_APPLE, U_BANANA]
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        yield self.handler.started_typing(
-            target_user=self.u_apple,
-            auth_user=self.u_apple,
-            room_id=self.room_id,
+        self.successResultOf(self.handler.started_typing(
+            target_user=U_APPLE,
+            auth_user=U_APPLE,
+            room_id=ROOM_ID,
             timeout=20000,
-        )
+        ))
 
         self.on_new_event.assert_has_calls(
-            [call('typing_key', 1, rooms=[self.room_id])]
+            [call('typing_key', 1, rooms=[ROOM_ID])]
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events(
-            room_ids=[self.room_id], from_key=0
+        events = self.event_source.get_new_events(
+            room_ids=[ROOM_ID], from_key=0
         )
         self.assertEquals(
             events[0],
             [
                 {
                     "type": "m.typing",
-                    "room_id": self.room_id,
-                    "content": {"user_ids": [self.u_apple.to_string()]},
+                    "room_id": ROOM_ID,
+                    "content": {"user_ids": [U_APPLE.to_string()]},
                 }
             ],
         )
 
-    @defer.inlineCallbacks
     def test_started_typing_remote_send(self):
-        self.room_members = [self.u_apple, self.u_onion]
-
-        put_json = self.mock_http_client.put_json
-        put_json.expect_call_and_return(
-            call(
-                "farm",
-                path="/_matrix/federation/v1/send/1000000/",
-                data=_expect_edu_transaction(
-                    "m.typing",
-                    content={
-                        "room_id": self.room_id,
-                        "user_id": self.u_apple.to_string(),
-                        "typing": True,
-                    },
-                ),
-                json_data_callback=ANY,
-                long_retries=True,
-                backoff_on_404=True,
-            ),
-            defer.succeed((200, "OK")),
-        )
+        self.room_members = [U_APPLE, U_ONION]
 
-        yield self.handler.started_typing(
-            target_user=self.u_apple,
-            auth_user=self.u_apple,
-            room_id=self.room_id,
+        self.successResultOf(self.handler.started_typing(
+            target_user=U_APPLE,
+            auth_user=U_APPLE,
+            room_id=ROOM_ID,
             timeout=20000,
-        )
+        ))
 
-        yield put_json.await_calls()
+        put_json = self.hs.get_http_client().put_json
+        put_json.assert_called_once_with(
+            "farm",
+            path="/_matrix/federation/v1/send/1000000/",
+            data=_expect_edu_transaction(
+                "m.typing",
+                content={
+                    "room_id": ROOM_ID,
+                    "user_id": U_APPLE.to_string(),
+                    "typing": True,
+                },
+            ),
+            json_data_callback=ANY,
+            long_retries=True,
+            backoff_on_404=True,
+        )
 
-    @defer.inlineCallbacks
     def test_started_typing_remote_recv(self):
-        self.room_members = [self.u_apple, self.u_onion]
+        self.room_members = [U_APPLE, U_ONION]
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        (code, response) = yield self.mock_federation_resource.trigger(
+        (request, channel) = self.make_request(
             "PUT",
             "/_matrix/federation/v1/send/1000000/",
             _make_edu_transaction_json(
                 "m.typing",
                 content={
-                    "room_id": self.room_id,
-                    "user_id": self.u_onion.to_string(),
+                    "room_id": ROOM_ID,
+                    "user_id": U_ONION.to_string(),
                     "typing": True,
                 },
             ),
             federation_auth_origin=b'farm',
         )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
 
         self.on_new_event.assert_has_calls(
-            [call('typing_key', 1, rooms=[self.room_id])]
+            [call('typing_key', 1, rooms=[ROOM_ID])]
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events(
-            room_ids=[self.room_id], from_key=0
+        events = self.event_source.get_new_events(
+            room_ids=[ROOM_ID], from_key=0
         )
         self.assertEquals(
             events[0],
             [
                 {
                     "type": "m.typing",
-                    "room_id": self.room_id,
-                    "content": {"user_ids": [self.u_onion.to_string()]},
+                    "room_id": ROOM_ID,
+                    "content": {"user_ids": [U_ONION.to_string()]},
                 }
             ],
         )
 
-    @defer.inlineCallbacks
     def test_stopped_typing(self):
-        self.room_members = [self.u_apple, self.u_banana, self.u_onion]
-
-        put_json = self.mock_http_client.put_json
-        put_json.expect_call_and_return(
-            call(
-                "farm",
-                path="/_matrix/federation/v1/send/1000000/",
-                data=_expect_edu_transaction(
-                    "m.typing",
-                    content={
-                        "room_id": self.room_id,
-                        "user_id": self.u_apple.to_string(),
-                        "typing": False,
-                    },
-                ),
-                json_data_callback=ANY,
-                long_retries=True,
-                backoff_on_404=True,
-            ),
-            defer.succeed((200, "OK")),
-        )
+        self.room_members = [U_APPLE, U_BANANA, U_ONION]
 
         # Gut-wrenching
         from synapse.handlers.typing import RoomMember
 
-        member = RoomMember(self.room_id, self.u_apple.to_string())
+        member = RoomMember(ROOM_ID, U_APPLE.to_string())
         self.handler._member_typing_until[member] = 1002000
-        self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
+        self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        yield self.handler.stopped_typing(
-            target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id
-        )
+        self.successResultOf(self.handler.stopped_typing(
+            target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+        ))
 
         self.on_new_event.assert_has_calls(
-            [call('typing_key', 1, rooms=[self.room_id])]
+            [call('typing_key', 1, rooms=[ROOM_ID])]
         )
 
-        yield put_json.await_calls()
+        put_json = self.hs.get_http_client().put_json
+        put_json.assert_called_once_with(
+            "farm",
+            path="/_matrix/federation/v1/send/1000000/",
+            data=_expect_edu_transaction(
+                "m.typing",
+                content={
+                    "room_id": ROOM_ID,
+                    "user_id": U_APPLE.to_string(),
+                    "typing": False,
+                },
+            ),
+            json_data_callback=ANY,
+            long_retries=True,
+            backoff_on_404=True,
+        )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events(
-            room_ids=[self.room_id], from_key=0
+        events = self.event_source.get_new_events(
+            room_ids=[ROOM_ID], from_key=0
         )
         self.assertEquals(
             events[0],
             [
                 {
                     "type": "m.typing",
-                    "room_id": self.room_id,
+                    "room_id": ROOM_ID,
                     "content": {"user_ids": []},
                 }
             ],
         )
 
-    @defer.inlineCallbacks
     def test_typing_timeout(self):
-        self.room_members = [self.u_apple, self.u_banana]
+        self.room_members = [U_APPLE, U_BANANA]
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        yield self.handler.started_typing(
-            target_user=self.u_apple,
-            auth_user=self.u_apple,
-            room_id=self.room_id,
+        self.successResultOf(self.handler.started_typing(
+            target_user=U_APPLE,
+            auth_user=U_APPLE,
+            room_id=ROOM_ID,
             timeout=10000,
-        )
+        ))
 
         self.on_new_event.assert_has_calls(
-            [call('typing_key', 1, rooms=[self.room_id])]
+            [call('typing_key', 1, rooms=[ROOM_ID])]
         )
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events(
-            room_ids=[self.room_id], from_key=0
+        events = self.event_source.get_new_events(
+            room_ids=[ROOM_ID], from_key=0
         )
         self.assertEquals(
             events[0],
             [
                 {
                     "type": "m.typing",
-                    "room_id": self.room_id,
-                    "content": {"user_ids": [self.u_apple.to_string()]},
+                    "room_id": ROOM_ID,
+                    "content": {"user_ids": [U_APPLE.to_string()]},
                 }
             ],
         )
 
-        self.clock.advance_time(16)
+        self.reactor.pump([16, ])
 
         self.on_new_event.assert_has_calls(
-            [call('typing_key', 2, rooms=[self.room_id])]
+            [call('typing_key', 2, rooms=[ROOM_ID])]
         )
 
         self.assertEquals(self.event_source.get_current_key(), 2)
-        events = yield self.event_source.get_new_events(
-            room_ids=[self.room_id], from_key=1
+        events = self.event_source.get_new_events(
+            room_ids=[ROOM_ID], from_key=1
         )
         self.assertEquals(
             events[0],
             [
                 {
                     "type": "m.typing",
-                    "room_id": self.room_id,
+                    "room_id": ROOM_ID,
                     "content": {"user_ids": []},
                 }
             ],
@@ -365,29 +341,29 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         # SYN-230 - see if we can still set after timeout
 
-        yield self.handler.started_typing(
-            target_user=self.u_apple,
-            auth_user=self.u_apple,
-            room_id=self.room_id,
+        self.successResultOf(self.handler.started_typing(
+            target_user=U_APPLE,
+            auth_user=U_APPLE,
+            room_id=ROOM_ID,
             timeout=10000,
-        )
+        ))
 
         self.on_new_event.assert_has_calls(
-            [call('typing_key', 3, rooms=[self.room_id])]
+            [call('typing_key', 3, rooms=[ROOM_ID])]
         )
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
-        events = yield self.event_source.get_new_events(
-            room_ids=[self.room_id], from_key=0
+        events = self.event_source.get_new_events(
+            room_ids=[ROOM_ID], from_key=0
         )
         self.assertEquals(
             events[0],
             [
                 {
                     "type": "m.typing",
-                    "room_id": self.room_id,
-                    "content": {"user_ids": [self.u_apple.to_string()]},
+                    "room_id": ROOM_ID,
+                    "content": {"user_ids": [U_APPLE.to_string()]},
                 }
             ],
         )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 9e9fbbfe93..524af4f8d1 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(
             "blue",
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
-        hs.get_ratelimiter().send_message.return_value = (True, 0)
+        hs.get_ratelimiter().can_do_action.return_value = (True, 0)
 
         return hs
 
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 407bf0ac4c..ea03b7e523 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -20,14 +20,48 @@ import json
 from mock import Mock
 
 from synapse.api.constants import UserTypes
-from synapse.rest.client.v1.admin import register_servlets
+from synapse.rest.client.v1 import admin, login
 
 from tests import unittest
 
 
+class VersionTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    url = '/_matrix/client/r0/admin/server_version'
+
+    def test_version_string(self):
+        self.register_user("admin", "pass", admin=True)
+        self.admin_token = self.login("admin", "pass")
+
+        request, channel = self.make_request("GET", self.url,
+                                             access_token=self.admin_token)
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]),
+                         msg=channel.result["body"])
+        self.assertEqual({'server_version', 'python_version'},
+                         set(channel.json_body.keys()))
+
+    def test_inaccessible_to_non_admins(self):
+        self.register_user("unprivileged-user", "pass", admin=False)
+        user_token = self.login("unprivileged-user", "pass")
+
+        request, channel = self.make_request("GET", self.url,
+                                             access_token=user_token)
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result['code']),
+                         msg=channel.result['body'])
+
+
 class UserRegisterTestCase(unittest.HomeserverTestCase):
 
-    servlets = [register_servlets]
+    servlets = [admin.register_servlets]
 
     def make_homeserver(self, reactor, clock):
 
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 483bebc832..36d8547275 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         config.auto_join_rooms = []
 
         hs = self.setup_test_homeserver(
-            config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
+            config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
         )
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         hs.get_handlers().federation_handler = Mock()
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index a824be9a62..015c144248 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
         self.ratelimiter = self.hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         self.hs.get_federation_handler = Mock(return_value=Mock())
 
@@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
         # auth as user_id now
         self.helper.auth_user_id = self.user_id
 
-    def test_send_message(self):
+    def test_can_do_action(self):
         msg_content = b'{"msgtype":"m.text","body":"hello"}'
 
         seq = iter(range(100))
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 0ad814c5e5..30fb77bac8 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
 
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         hs.get_handlers().federation_handler = Mock()
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 906b348d3e..3600434858 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+
+    def test_POST_ratelimiting_guest(self):
+        self.hs.config.rc_registration_request_burst_count = 5
+
+        for i in range(0, 6):
+            url = self.url + b"?kind=guest"
+            request, channel = self.make_request(b"POST", url, b"{}")
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_POST_ratelimiting(self):
+        self.hs.config.rc_registration_request_burst_count = 5
+
+        for i in range(0, 6):
+            params = {
+                "username": "kermit" + str(i),
+                "password": "monkey",
+                "device_id": "frogfone",
+                "auth": {"type": LoginType.DUMMY},
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", self.url, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/server.py b/tests/server.py
index fc1e76d146..37069afdda 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -137,6 +137,7 @@ def make_request(
     access_token=None,
     request=SynapseRequest,
     shorthand=True,
+    federation_auth_origin=None,
 ):
     """
     Make a web request using the given method and path, feed it the
@@ -150,9 +151,11 @@ def make_request(
         a dict.
         shorthand: Whether to try and be helpful and prefix the given URL
         with the usual REST API path, if it doesn't contain it.
+        federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+            Authorization header pretenting to be the given server name.
 
     Returns:
-        A synapse.http.site.SynapseRequest.
+        Tuple[synapse.http.site.SynapseRequest, channel]
     """
     if not isinstance(method, bytes):
         method = method.encode('ascii')
@@ -184,6 +187,11 @@ def make_request(
             b"Authorization", b"Bearer " + access_token.encode('ascii')
         )
 
+    if federation_auth_origin is not None:
+        req.requestHeaders.addRawHeader(
+            b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
+        )
+
     if content:
         req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
@@ -288,9 +296,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
             **kwargs
         )
 
-    pool.runWithConnection = runWithConnection
-    pool.runInteraction = runInteraction
-
     class ThreadPool:
         """
         Threadless thread pool.
@@ -316,8 +321,12 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
             return d
 
     clock.threadpool = ThreadPool()
-    pool.threadpool = ThreadPool()
-    pool.running = True
+
+    if pool:
+        pool.runWithConnection = runWithConnection
+        pool.runInteraction = runInteraction
+        pool.threadpool = ThreadPool()
+        pool.running = True
     return d
 
 
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 04f95c942f..00be1a8c21 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -17,7 +17,7 @@
 
 import json
 
-from mock import Mock, NonCallableMock
+from mock import Mock
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
@@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
 
         self.store = self.hs.get_datastore()
diff --git a/tests/unittest.py b/tests/unittest.py
index fac254ff10..ef31321bc8 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -262,6 +262,7 @@ class HomeserverTestCase(TestCase):
         access_token=None,
         request=SynapseRequest,
         shorthand=True,
+        federation_auth_origin=None,
     ):
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -275,15 +276,18 @@ class HomeserverTestCase(TestCase):
             a dict.
             shorthand: Whether to try and be helpful and prefix the given URL
             with the usual REST API path, if it doesn't contain it.
+            federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+                Authorization header pretenting to be the given server name.
 
         Returns:
-            A synapse.http.site.SynapseRequest.
+            Tuple[synapse.http.site.SynapseRequest, channel]
         """
         if isinstance(content, dict):
             content = json.dumps(content).encode('utf8')
 
         return make_request(
-            self.reactor, method, path, content, access_token, request, shorthand
+            self.reactor, method, path, content, access_token, request, shorthand,
+            federation_auth_origin,
         )
 
     def render(self, request):
diff --git a/tests/utils.py b/tests/utils.py
index cf49833a43..e4c42f9fa8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -29,7 +29,7 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes, RoomVersions
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.config.server import ServerConfig
-from synapse.federation.transport import server
+from synapse.federation.transport import server as federation_server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -150,6 +150,8 @@ def default_config(name):
     config.admin_contact = None
     config.rc_messages_per_second = 10000
     config.rc_message_burst_count = 10000
+    config.rc_registration_request_burst_count = 3.0
+    config.rc_registration_requests_per_second = 0.17
     config.saml2_enabled = False
     config.public_baseurl = None
     config.default_identity_server = None
@@ -200,6 +202,9 @@ def setup_test_homeserver(
     Args:
         cleanup_func : The function used to register a cleanup routine for
                        after the test.
+
+    Calling this method directly is deprecated: you should instead derive from
+    HomeserverTestCase.
     """
     if reactor is None:
         from twisted.internet import reactor
@@ -351,23 +356,27 @@ def setup_test_homeserver(
 
     fed = kargs.get("resource_for_federation", None)
     if fed:
-        server.register_servlets(
-            hs,
-            resource=fed,
-            authenticator=server.Authenticator(hs),
-            ratelimiter=FederationRateLimiter(
-                hs.get_clock(),
-                window_size=hs.config.federation_rc_window_size,
-                sleep_limit=hs.config.federation_rc_sleep_limit,
-                sleep_msec=hs.config.federation_rc_sleep_delay,
-                reject_limit=hs.config.federation_rc_reject_limit,
-                concurrent_requests=hs.config.federation_rc_concurrent,
-            ),
-        )
+        register_federation_servlets(hs, fed)
 
     defer.returnValue(hs)
 
 
+def register_federation_servlets(hs, resource):
+    federation_server.register_servlets(
+        hs,
+        resource=resource,
+        authenticator=federation_server.Authenticator(hs),
+        ratelimiter=FederationRateLimiter(
+            hs.get_clock(),
+            window_size=hs.config.federation_rc_window_size,
+            sleep_limit=hs.config.federation_rc_sleep_limit,
+            sleep_msec=hs.config.federation_rc_sleep_delay,
+            reject_limit=hs.config.federation_rc_reject_limit,
+            concurrent_requests=hs.config.federation_rc_concurrent,
+        ),
+    )
+
+
 def get_mock_call_args(pattern_func, mock_func):
     """ Return the arguments the mock function was called with interpreted
     by the pattern functions argument list.