diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
new file mode 100644
index 0000000000..4294bbec2a
--- /dev/null
+++ b/tests/rest/client/test_consent.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from synapse.api.urls import ConsentURIBuilder
+from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.consent import consent_resource
+
+from tests import unittest
+from tests.server import render
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except Exception:
+ load_jinja2_templates = None
+
+
+class ConsentResourceTestCase(unittest.HomeserverTestCase):
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config.user_consent_version = "1"
+ config.public_baseurl = ""
+ config.form_secret = "123abc"
+
+ # Make some temporary templates...
+ temp_consent_path = self.mktemp()
+ os.mkdir(temp_consent_path)
+ os.mkdir(os.path.join(temp_consent_path, 'en'))
+ config.user_consent_template_dir = os.path.abspath(temp_consent_path)
+
+ with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
+ f.write("{{version}},{{has_consented}}")
+
+ with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f:
+ f.write("yay!")
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_render_public_consent(self):
+ """You can observe the terms form without specifying a user"""
+ resource = consent_resource.ConsentResource(self.hs)
+ request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ def test_accept_consent(self):
+ """
+ A user can use the consent form to accept the terms.
+ """
+ uri_builder = ConsentURIBuilder(self.hs.config)
+ resource = consent_resource.ConsentResource(self.hs)
+
+ # Register a user
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Fetch the consent page, to get the consent version
+ consent_uri = (
+ uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ + "&u=user"
+ )
+ request, channel = self.make_request(
+ "GET", consent_uri, access_token=access_token, shorthand=False
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Get the version from the body, and whether we've consented
+ version, consented = channel.result["body"].decode('ascii').split(",")
+ self.assertEqual(consented, "False")
+
+ # POST to the consent page, saying we've agreed
+ request, channel = self.make_request(
+ "POST",
+ consent_uri + "&v=" + version,
+ access_token=access_token,
+ shorthand=False,
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Fetch the consent page, to get the consent version -- it should have
+ # changed
+ request, channel = self.make_request(
+ "GET", consent_uri, access_token=access_token, shorthand=False
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Get the version from the body, and check that it's the version we
+ # agreed to, and that we've consented to it.
+ version, consented = channel.result["body"].decode('ascii').split(",")
+ self.assertEqual(consented, "True")
+ self.assertEqual(version, "1")
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 1a553fa3f9..407bf0ac4c 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -19,24 +19,18 @@ import json
from mock import Mock
-from synapse.http.server import JsonResource
+from synapse.api.constants import UserTypes
from synapse.rest.client.v1.admin import register_servlets
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
-class UserRegisterTestCase(unittest.TestCase):
- def setUp(self):
+class UserRegisterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [register_servlets]
+
+ def make_homeserver(self, reactor, clock):
- self.clock = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock()
@@ -50,17 +44,14 @@ class UserRegisterTestCase(unittest.TestCase):
self.secrets = Mock()
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
+ self.hs = self.setup_test_homeserver()
self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
- self.resource = JsonResource(self.hs)
- register_servlets(self.hs, self.resource)
+ return self.hs
def test_disabled(self):
"""
@@ -69,8 +60,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
self.hs.config.registration_shared_secret = None
- request, channel = make_request("POST", self.url, b'{}')
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, b'{}')
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -87,8 +78,8 @@ class UserRegisterTestCase(unittest.TestCase):
self.hs.get_secrets = Mock(return_value=secrets)
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -97,25 +88,25 @@ class UserRegisterTestCase(unittest.TestCase):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
# 59 seconds
- self.clock.advance(59)
+ self.reactor.advance(59)
body = json.dumps({"nonce": nonce})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds
- self.clock.advance(2)
+ self.reactor.advance(2)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -124,8 +115,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -141,8 +132,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -152,12 +143,14 @@ class UserRegisterTestCase(unittest.TestCase):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+ want_mac.update(
+ nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support"
+ )
want_mac = want_mac.hexdigest()
body = json.dumps(
@@ -166,11 +159,12 @@ class UserRegisterTestCase(unittest.TestCase):
"username": "bob",
"password": "abc123",
"admin": True,
+ "user_type": UserTypes.SUPPORT,
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -179,12 +173,14 @@ class UserRegisterTestCase(unittest.TestCase):
"""
A valid unrecognised nonce.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+ want_mac.update(
+ nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin"
+ )
want_mac = want_mac.hexdigest()
body = json.dumps(
@@ -196,15 +192,15 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -212,13 +208,13 @@ class UserRegisterTestCase(unittest.TestCase):
def test_missing_parts(self):
"""
Synapse will complain if you don't give nonce, username, password, and
- mac. Admin is optional. Additional checks are done for length and
- type.
+ mac. Admin and user_types are optional. Additional checks are done for length
+ and type.
"""
def nonce():
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
return channel.json_body["nonce"]
#
@@ -227,8 +223,8 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
@@ -239,52 +235,52 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
#
- # Username checks
+ # Password checks
#
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
@@ -293,16 +289,33 @@ class UserRegisterTestCase(unittest.TestCase):
body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
+
+ #
+ # user_type check
+ #
+
+ # Invalid user_type
+ body = json.dumps({
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid"}
+ )
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid user type', channel.json_body["error"])
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 956f7fc4c4..483bebc832 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -16,64 +16,49 @@
""" Tests REST events for /events paths."""
from mock import Mock, NonCallableMock
-from six import PY3
-from twisted.internet import defer
+from synapse.rest.client.v1 import admin, events, login, room
-from ....utils import MockHttpResource, setup_test_homeserver
-from .utils import RestTestCase
+from tests import unittest
-PATH_PREFIX = "/_matrix/client/api/v1"
-
-class EventStreamPermissionsTestCase(RestTestCase):
+class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
""" Tests event streaming (GET /events). """
- if PY3:
- skip = "Skip on Py3 until ported to use not V1 only register."
+ servlets = [
+ events.register_servlets,
+ room.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
- @defer.inlineCallbacks
- def setUp(self):
- import synapse.rest.client.v1.events
- import synapse.rest.client.v1_only.register
- import synapse.rest.client.v1.room
+ def make_homeserver(self, reactor, clock):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
+ config = self.default_config()
+ config.enable_registration_captcha = False
+ config.enable_registration = True
+ config.auto_join_rooms = []
- hs = yield setup_test_homeserver(
- self.addCleanup,
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
+ hs = self.setup_test_homeserver(
+ config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
- hs.config.enable_registration_captcha = False
- hs.config.enable_registration = True
- hs.config.auto_join_rooms = []
hs.get_handlers().federation_handler = Mock()
- synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
- synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ return hs
+
+ def prepare(self, hs, reactor, clock):
# register an account
- self.user_id = "sid1"
- response = yield self.register(self.user_id)
- self.token = response["access_token"]
- self.user_id = response["user_id"]
+ self.user_id = self.register_user("sid1", "pass")
+ self.token = self.login(self.user_id, "pass")
# register a 2nd account
- self.other_user = "other1"
- response = yield self.register(self.other_user)
- self.other_token = response["access_token"]
- self.other_user = response["user_id"]
+ self.other_user = self.register_user("other2", "pass")
+ self.other_token = self.login(self.other_user, "pass")
- def tearDown(self):
- pass
-
- @defer.inlineCallbacks
def test_stream_basic_permissions(self):
# invalid token, expect 401
# note: this is in violation of the original v1 spec, which expected
@@ -81,34 +66,37 @@ class EventStreamPermissionsTestCase(RestTestCase):
# implementation is now part of the r0 implementation, the newer
# behaviour is used instead to be consistent with the r0 spec.
# see issue #2602
- (code, response) = yield self.mock_resource.trigger_get(
- "/events?access_token=%s" % ("invalid" + self.token,)
+ request, channel = self.make_request(
+ "GET", "/events?access_token=%s" % ("invalid" + self.token,)
)
- self.assertEquals(401, code, msg=str(response))
+ self.render(request)
+ self.assertEquals(channel.code, 401, msg=channel.result)
# valid token, expect content
- (code, response) = yield self.mock_resource.trigger_get(
- "/events?access_token=%s&timeout=0" % (self.token,)
+ request, channel = self.make_request(
+ "GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
- self.assertEquals(200, code, msg=str(response))
- self.assertTrue("chunk" in response)
- self.assertTrue("start" in response)
- self.assertTrue("end" in response)
+ self.render(request)
+ self.assertEquals(channel.code, 200, msg=channel.result)
+ self.assertTrue("chunk" in channel.json_body)
+ self.assertTrue("start" in channel.json_body)
+ self.assertTrue("end" in channel.json_body)
- @defer.inlineCallbacks
def test_stream_room_permissions(self):
- room_id = yield self.create_room_as(self.other_user, tok=self.other_token)
- yield self.send(room_id, tok=self.other_token)
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
+ self.helper.send(room_id, tok=self.other_token)
# invited to room (expect no content for room)
- yield self.invite(
+ self.helper.invite(
room_id, src=self.other_user, targ=self.user_id, tok=self.other_token
)
- (code, response) = yield self.mock_resource.trigger_get(
- "/events?access_token=%s&timeout=0" % (self.token,)
+ # valid token, expect content
+ request, channel = self.make_request(
+ "GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
- self.assertEquals(200, code, msg=str(response))
+ self.render(request)
+ self.assertEquals(channel.code, 200, msg=channel.result)
# We may get a presence event for ourselves down
self.assertEquals(
@@ -116,7 +104,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
len(
[
c
- for c in response["chunk"]
+ for c in channel.json_body["chunk"]
if not (
c.get("type") == "m.presence"
and c["content"].get("user_id") == self.user_id
@@ -126,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
)
# joined room (expect all content for room)
- yield self.join(room=room_id, user=self.user_id, tok=self.token)
+ self.helper.join(room=room_id, user=self.user_id, tok=self.token)
# left to room (expect no content for room)
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
deleted file mode 100644
index 6b7ff813d5..0000000000
--- a/tests/rest/client/v1/test_register.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-
-from mock import Mock
-from six import PY3
-
-from twisted.test.proto_helpers import MemoryReactorClock
-
-from synapse.http.server import JsonResource
-from synapse.rest.client.v1_only.register import register_servlets
-from synapse.util import Clock
-
-from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
-
-
-class CreateUserServletTestCase(unittest.TestCase):
- """
- Tests for CreateUserRestServlet.
- """
-
- if PY3:
- skip = "Not ported to Python 3."
-
- def setUp(self):
- self.registration_handler = Mock()
-
- self.appservice = Mock(sender="@as:test")
- self.datastore = Mock(
- get_app_service_by_token=Mock(return_value=self.appservice)
- )
-
- handlers = Mock(registration_handler=self.registration_handler)
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
-
- self.hs = self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
- self.hs.get_datastore = Mock(return_value=self.datastore)
- self.hs.get_handlers = Mock(return_value=handlers)
-
- def test_POST_createuser_with_valid_user(self):
-
- res = JsonResource(self.hs)
- register_servlets(self.hs, res)
-
- request_data = json.dumps(
- {
- "localpart": "someone",
- "displayname": "someone interesting",
- "duration_seconds": 200,
- }
- )
-
- url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
-
- user_id = "@someone:interesting"
- token = "my token"
-
- self.registration_handler.get_or_create_user = Mock(
- return_value=(user_id, token)
- )
-
- request, channel = make_request(b"POST", url, request_data)
- render(request, res, self.clock)
-
- self.assertEquals(channel.result["code"], b"200")
-
- det_data = {
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- }
- self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 530dc8ba6d..9c401bf300 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -169,7 +169,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
request, channel = make_request(
- "POST", path, json.dumps(content).encode('utf8')
+ self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())
@@ -217,7 +217,9 @@ class RestHelper(object):
data = {"membership": membership}
- request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
+ )
render(request, self.resource, self.hs.get_reactor())
@@ -228,18 +230,6 @@ class RestHelper(object):
self.auth_user_id = temp_id
- @defer.inlineCallbacks
- def register(self, user_id):
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/_matrix/client/r0/register",
- json.dumps(
- {"user": user_id, "password": "test", "type": "m.login.password"}
- ),
- )
- self.assertEquals(200, code)
- defer.returnValue(response)
-
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -251,7 +241,9 @@ class RestHelper(object):
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
+ )
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
new file mode 100644
index 0000000000..7fa120a10f
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import admin
+from synapse.rest.client.v2_alpha import auth, register
+
+from tests import unittest
+
+
+class FallbackAuthTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ auth.register_servlets,
+ admin.register_servlets,
+ register.register_servlets,
+ ]
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+
+ config.enable_registration_captcha = True
+ config.recaptcha_public_key = "brokencake"
+ config.registrations_require_3pid = []
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ auth_handler = hs.get_auth_handler()
+
+ self.recaptcha_attempts = []
+
+ def _recaptcha(authdict, clientip):
+ self.recaptcha_attempts.append((authdict, clientip))
+ return succeed(True)
+
+ auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+
+ @unittest.INFO
+ def test_fallback_captcha(self):
+
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.render(request)
+
+ # Returns a 401 as per the spec
+ self.assertEqual(request.code, 401)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ request, channel = self.make_request(
+ "POST",
+ "auth/m.login.recaptcha/fallback/web?session="
+ + session
+ + "&g-recaptcha-response=a",
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ # The recaptcha handler is called with the response given
+ self.assertEqual(len(self.recaptcha_attempts), 1)
+ self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+
+ # Now we have fufilled the recaptcha fallback step, we can then send a
+ # request to the register API with the session in the authdict.
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session}}
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
new file mode 100644
index 0000000000..d3d43970fb
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v2_alpha import capabilities
+
+from tests import unittest
+
+
+class CapabilitiesTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ capabilities.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.url = b"/_matrix/client/r0/capabilities"
+ hs = self.setup_test_homeserver()
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_check_auth_required(self):
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+
+ self.assertEqual(channel.code, 401)
+
+ def test_get_room_version_capabilities(self):
+ self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ request, channel = self.make_request("GET", self.url, access_token=access_token)
+ self.render(request)
+ capabilities = channel.json_body['capabilities']
+
+ self.assertEqual(channel.code, 200)
+ for room_version in capabilities['m.room_versions']['available'].keys():
+ self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+ self.assertEqual(
+ DEFAULT_ROOM_VERSION, capabilities['m.room_versions']['default']
+ )
+
+ def test_get_change_password_capabilities(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.login(user, password)
+
+ request, channel = self.make_request("GET", self.url, access_token=access_token)
+ self.render(request)
+ capabilities = channel.json_body['capabilities']
+
+ self.assertEqual(channel.code, 200)
+
+ # Test case where password is handled outside of Synapse
+ self.assertTrue(capabilities['m.change_password']['enabled'])
+ self.get_success(self.store.user_set_password_hash(user, None))
+ request, channel = self.make_request("GET", self.url, access_token=access_token)
+ self.render(request)
+ capabilities = channel.json_body['capabilities']
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities['m.change_password']['enabled'])
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 6a886ee3b8..f42a8efbf4 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,84 +13,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.types
from synapse.api.errors import Codes
-from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import filter
-from synapse.types import UserID
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock as MemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
PATH_PREFIX = "/_matrix/client/v2_alpha"
-class FilterTestCase(unittest.TestCase):
+class FilterTestCase(unittest.HomeserverTestCase):
- USER_ID = "@apple:test"
+ user_id = "@apple:test"
+ hijack_auth = True
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
- TO_REGISTER = [filter]
+ servlets = [filter.register_servlets]
- def setUp(self):
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
-
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
-
- self.auth = self.hs.get_auth()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.USER_ID),
- "token_id": 1,
- "is_guest": False,
- }
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return synapse.types.create_requester(
- UserID.from_string(self.USER_ID), 1, False, None
- )
-
- self.auth.get_user_by_access_token = get_user_by_access_token
- self.auth.get_user_by_req = get_user_by_req
-
- self.store = self.hs.get_datastore()
- self.filtering = self.hs.get_filtering()
- self.resource = JsonResource(self.hs)
-
- for r in self.TO_REGISTER:
- r.register_servlets(self.hs, self.resource)
+ def prepare(self, reactor, clock, hs):
+ self.filtering = hs.get_filtering()
+ self.store = hs.get_datastore()
def test_add_filter(self):
- request, channel = make_request(
+ request, channel = self.make_request(
"POST",
- "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
- self.clock.advance(0)
+ self.pump()
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self):
- request, channel = make_request(
+ request, channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- request, channel = make_request(
+ request, channel = self.make_request(
"POST",
- "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
@@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
filter_id = self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
- self.clock.advance(1)
+ self.reactor.advance(1)
filter_id = filter_id.result
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self):
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
def test_get_filter_invalid_id(self):
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1c128e81f5..906b348d3e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,83 +1,51 @@
import json
-from mock import Mock
-
-from twisted.python import failure
-from twisted.test.proto_helpers import MemoryReactorClock
-
-from synapse.api.errors import InteractiveAuthIncompleteError
-from synapse.http.server import JsonResource
+from synapse.api.constants import LoginType
+from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha.register import register_servlets
-from synapse.util import Clock
from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
-class RegisterRestServletTestCase(unittest.TestCase):
- def setUp(self):
+class RegisterRestServletTestCase(unittest.HomeserverTestCase):
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
- self.url = b"/_matrix/client/r0/register"
+ servlets = [register_servlets]
- self.appservice = None
- self.auth = Mock(
- get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
- )
+ def make_homeserver(self, reactor, clock):
- self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
- self.auth_handler = Mock(
- check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
- get_session_data=Mock(return_value=None),
- )
- self.registration_handler = Mock()
- self.identity_handler = Mock()
- self.login_handler = Mock()
- self.device_handler = Mock()
- self.device_handler.check_device_registered = Mock(return_value="FAKE")
-
- self.datastore = Mock(return_value=Mock())
- self.datastore.get_current_state_deltas = Mock(return_value=[])
-
- # do the dance to hook it up to the hs global
- self.handlers = Mock(
- registration_handler=self.registration_handler,
- identity_handler=self.identity_handler,
- login_handler=self.login_handler,
- )
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
- self.hs.get_auth = Mock(return_value=self.auth)
- self.hs.get_handlers = Mock(return_value=self.handlers)
- self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
- self.hs.get_device_handler = Mock(return_value=self.device_handler)
- self.hs.get_datastore = Mock(return_value=self.datastore)
+ self.url = b"/_matrix/client/r0/register"
+
+ self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
+ self.hs.config.enable_registration_captcha = False
- self.resource = JsonResource(self.hs)
- register_servlets(self.hs, self.resource)
+ return self.hs
def test_POST_appservice_registration_valid(self):
- user_id = "@kermit:muppet"
- token = "kermits_access_token"
- self.appservice = {"id": "1234"}
- self.registration_handler.appservice_register = Mock(return_value=user_id)
- self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
- request_data = json.dumps({"username": "kermit"})
+ user_id = "@as_user_kermit:test"
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token, self.hs.config.hostname,
+ id="1234",
+ namespaces={
+ "users": [{"regex": r"@as_user.*", "exclusive": True}],
+ },
+ )
+
+ self.hs.get_datastore().services_cache.append(appservice)
+ request_data = json.dumps({"username": "as_user_kermit"})
- request, channel = make_request(
+ request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
"user_id": user_id,
- "access_token": token,
"home_server": self.hs.hostname,
}
self.assertDictContainsSubset(det_data, channel.json_body)
@@ -85,81 +53,69 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
request_data = json.dumps({"username": "kermit"})
- request, channel = make_request(
+ request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self):
- user_id = "@kermit:muppet"
- token = "kermits_access_token"
+ user_id = "@kermit:test"
device_id = "frogfone"
- request_data = json.dumps(
- {"username": "kermit", "password": "monkey", "device_id": device_id}
- )
- self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
- self.registration_handler.register = Mock(return_value=(user_id, None))
- self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
- self.device_handler.check_device_registered = Mock(return_value=device_id)
-
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ "device_id": device_id,
+ "auth": {"type": LoginType.DUMMY},
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
det_data = {
"user_id": user_id,
- "access_token": token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
- self.auth_handler.get_login_tuple_for_user_id(
- user_id, device_id=device_id, initial_device_display_name=None
- )
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
- self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
- self.registration_handler.register = Mock(return_value=("@user:id", "t"))
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
def test_POST_guest_registration(self):
- user_id = "a@b"
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
- self.registration_handler.register = Mock(return_value=(user_id, None))
- request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ self.render(request)
det_data = {
- "user_id": user_id,
"home_server": self.hs.hostname,
"device_id": "guest_device",
}
@@ -169,8 +125,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
- request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 4c30c5f258..99b716f00a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,9 +15,11 @@
from mock import Mock
+from synapse.rest.client.v1 import admin, login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
@@ -65,3 +67,124 @@ class FilterTestCase(unittest.HomeserverTestCase):
["next_batch", "rooms", "account_data", "to_device", "device_lists"]
).issubset(set(channel.json_body.keys()))
)
+
+
+class SyncTypingTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def test_sync_backwards_typing(self):
+ """
+ If the typing serial goes backwards and the typing handler is then reset
+ (such as when the master restarts and sets the typing serial to 0), we
+ do not incorrectly return typing information that had a serial greater
+ than the now-reset serial.
+ """
+ typing_url = "/rooms/%s/typing/%s?access_token=%s"
+ sync_url = "/sync?timeout=3000000&access_token=%s&since=%s"
+
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ request, channel = self.make_request(
+ "GET", "/sync?access_token=%s" % (access_token,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Stop typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": false}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Should return immediately
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Reset typing serial back to 0, as if the master had.
+ typing = self.hs.get_typing_handler()
+ typing._latest_room_serial = 0
+
+ # Since it checks the state token, we need some state to update to
+ # invalidate the stream token.
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # This should time out! But it does not, because our stream token is
+ # ahead, and therefore it's saying the typing (that we've actually
+ # already seen) is new, since it's got a token above our new, now-reset
+ # stream token.
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Clear the typing information, so that it doesn't think everything is
+ # in the future.
+ typing._reset()
+
+ # Now it SHOULD fail as it never completes!
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.assertRaises(TimedOutException, self.render, request)
|