summary refs log tree commit diff
path: root/tests/rest/client/v2_alpha
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v2_alpha')
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py95
-rw-r--r--tests/rest/client/v2_alpha/test_register.py52
2 files changed, 52 insertions, 95 deletions
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 6a886ee3b8..f42a8efbf4 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,84 +13,47 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import synapse.types
 from synapse.api.errors import Codes
-from synapse.http.server import JsonResource
 from synapse.rest.client.v2_alpha import filter
-from synapse.types import UserID
-from synapse.util import Clock
 
 from tests import unittest
-from tests.server import (
-    ThreadedMemoryReactorClock as MemoryReactorClock,
-    make_request,
-    render,
-    setup_test_homeserver,
-)
 
 PATH_PREFIX = "/_matrix/client/v2_alpha"
 
 
-class FilterTestCase(unittest.TestCase):
+class FilterTestCase(unittest.HomeserverTestCase):
 
-    USER_ID = "@apple:test"
+    user_id = "@apple:test"
+    hijack_auth = True
     EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
     EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
-    TO_REGISTER = [filter]
+    servlets = [filter.register_servlets]
 
-    def setUp(self):
-        self.clock = MemoryReactorClock()
-        self.hs_clock = Clock(self.clock)
-
-        self.hs = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
-        )
-
-        self.auth = self.hs.get_auth()
-
-        def get_user_by_access_token(token=None, allow_guest=False):
-            return {
-                "user": UserID.from_string(self.USER_ID),
-                "token_id": 1,
-                "is_guest": False,
-            }
-
-        def get_user_by_req(request, allow_guest=False, rights="access"):
-            return synapse.types.create_requester(
-                UserID.from_string(self.USER_ID), 1, False, None
-            )
-
-        self.auth.get_user_by_access_token = get_user_by_access_token
-        self.auth.get_user_by_req = get_user_by_req
-
-        self.store = self.hs.get_datastore()
-        self.filtering = self.hs.get_filtering()
-        self.resource = JsonResource(self.hs)
-
-        for r in self.TO_REGISTER:
-            r.register_servlets(self.hs, self.resource)
+    def prepare(self, reactor, clock, hs):
+        self.filtering = hs.get_filtering()
+        self.store = hs.get_datastore()
 
     def test_add_filter(self):
-        request, channel = make_request(
+        request, channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+            "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEqual(channel.json_body, {"filter_id": "0"})
         filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
-        self.clock.advance(0)
+        self.pump()
         self.assertEquals(filter.result, self.EXAMPLE_FILTER)
 
     def test_add_filter_for_other_user(self):
-        request, channel = make_request(
+        request, channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
             self.EXAMPLE_FILTER_JSON,
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"403")
         self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
     def test_add_filter_non_local_user(self):
         _is_mine = self.hs.is_mine
         self.hs.is_mine = lambda target_user: False
-        request, channel = make_request(
+        request, channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+            "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.hs.is_mine = _is_mine
         self.assertEqual(channel.result["code"], b"403")
@@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
         filter_id = self.filtering.add_user_filter(
             user_localpart="apple", user_filter=self.EXAMPLE_FILTER
         )
-        self.clock.advance(1)
+        self.reactor.advance(1)
         filter_id = filter_id.result
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
 
     def test_get_filter_non_existant(self):
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"400")
         self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
     # Currently invalid params do not have an appropriate errcode
     # in errors.py
     def test_get_filter_invalid_id(self):
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"400")
 
     # No ID also returns an invalid_id error
     def test_get_filter_no_id(self):
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1c128e81f5..753d5c3e80 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -3,22 +3,19 @@ import json
 from mock import Mock
 
 from twisted.python import failure
-from twisted.test.proto_helpers import MemoryReactorClock
 
 from synapse.api.errors import InteractiveAuthIncompleteError
-from synapse.http.server import JsonResource
 from synapse.rest.client.v2_alpha.register import register_servlets
-from synapse.util import Clock
 
 from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
 
 
-class RegisterRestServletTestCase(unittest.TestCase):
-    def setUp(self):
+class RegisterRestServletTestCase(unittest.HomeserverTestCase):
+
+    servlets = [register_servlets]
+
+    def make_homeserver(self, reactor, clock):
 
-        self.clock = MemoryReactorClock()
-        self.hs_clock = Clock(self.clock)
         self.url = b"/_matrix/client/r0/register"
 
         self.appservice = None
@@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             identity_handler=self.identity_handler,
             login_handler=self.login_handler,
         )
-        self.hs = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
-        )
+        self.hs = self.setup_test_homeserver()
         self.hs.get_auth = Mock(return_value=self.auth)
         self.hs.get_handlers = Mock(return_value=self.handlers)
         self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.config.registrations_require_3pid = []
         self.hs.config.auto_join_rooms = []
 
-        self.resource = JsonResource(self.hs)
-        register_servlets(self.hs, self.resource)
+        return self.hs
 
     def test_POST_appservice_registration_valid(self):
         user_id = "@kermit:muppet"
@@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
         request_data = json.dumps({"username": "kermit"})
 
-        request, channel = make_request(
+        request, channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
         det_data = {
@@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase):
     def test_POST_appservice_registration_invalid(self):
         self.appservice = None  # no application service exists
         request_data = json.dumps({"username": "kermit"})
-        request, channel = make_request(
+        request, channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
 
     def test_POST_bad_password(self):
         request_data = json.dumps({"username": "kermit", "password": 666})
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid password")
 
     def test_POST_bad_username(self):
         request_data = json.dumps({"username": 777, "password": "monkey"})
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
         self.device_handler.check_device_registered = Mock(return_value=device_id)
 
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         det_data = {
             "user_id": user_id,
@@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
         self.registration_handler.register = Mock(return_value=("@user:id", "t"))
 
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.config.allow_guest_access = True
         self.registration_handler.register = Mock(return_value=(user_id, None))
 
-        request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
 
         det_data = {
             "user_id": user_id,
@@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
     def test_POST_disabled_guest_registration(self):
         self.hs.config.allow_guest_access = False
 
-        request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")