diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 9e08eac0a5..c8994f416e 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -169,8 +169,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, 404)
@defer.inlineCallbacks
- def test_get_missing_room_keys(self):
- """Check that we get a 404 on querying missing room_keys
+ def test_get_missing_backup(self):
+ """Check that we get a 404 on querying missing backup
"""
res = None
try:
@@ -179,19 +179,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res = e.code
self.assertEqual(res, 404)
- # check we also get a 404 even if the version is valid
+ @defer.inlineCallbacks
+ def test_get_missing_room_keys(self):
+ """Check we get an empty response from an empty backup
+ """
version = yield self.handler.create_version(self.local_user, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
})
self.assertEqual(version, "1")
- res = None
- try:
- yield self.handler.get_room_keys(self.local_user, version)
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 404)
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest
@@ -345,17 +346,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check for bulk-delete
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(self.local_user, version)
- res = None
- try:
- yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 404)
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
# check for bulk-delete per room
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
@@ -364,17 +363,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
version,
room_id="!abc:matrix.org",
)
- res = None
- try:
- yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 404)
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
# check for bulk-delete per session
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
@@ -384,14 +381,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
room_id="!abc:matrix.org",
session_id="c0ff33",
)
- res = None
- try:
- yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 404)
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
new file mode 100644
index 0000000000..addc01ab7f
--- /dev/null
+++ b/tests/push/test_http.py
@@ -0,0 +1,159 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet.defer import Deferred
+
+from synapse.rest.client.v1 import admin, login, room
+
+from tests.unittest import HomeserverTestCase
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except Exception:
+ load_jinja2_templates = None
+
+
+class HTTPPusherTests(HomeserverTestCase):
+
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ self.push_attempts = []
+
+ m = Mock()
+
+ def post_json_get_json(url, body):
+ d = Deferred()
+ self.push_attempts.append((d, url, body))
+ return d
+
+ m.post_json_get_json = post_json_get_json
+
+ config = self.default_config()
+ config.start_pushers = True
+
+ hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+
+ return hs
+
+ def test_sends_http(self):
+ """
+ The HTTP pusher will send pushes for each message to a HTTP endpoint
+ when configured to do so.
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Get the stream ordering before it gets sent
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ last_stream_ordering = pushers[0]["last_stream_ordering"]
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # It hasn't succeeded yet, so the stream ordering shouldn't have moved
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+
+ # One push was attempted to be sent -- it'll be the first message
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
+ )
+
+ # Make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # The stream ordering has increased
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ last_stream_ordering = pushers[0]["last_stream_ordering"]
+
+ # Now it'll try and send the second push message, which will be the second one
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
+ )
+
+ # Make the second push succeed
+ self.push_attempts[1][0].callback({})
+ self.pump()
+
+ # The stream ordering has increased, again
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 41be5d5a1a..1688a741d1 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -28,8 +28,8 @@ ROOM_ID = "!room:blue"
def dict_equals(self, other):
- me = encode_canonical_json(self._event_dict)
- them = encode_canonical_json(other._event_dict)
+ me = encode_canonical_json(self.get_pdu_json())
+ them = encode_canonical_json(other.get_pdu_json())
return me == them
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
new file mode 100644
index 0000000000..df3f1cde6e
--- /dev/null
+++ b/tests/rest/client/test_consent.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from synapse.api.urls import ConsentURIBuilder
+from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.consent import consent_resource
+
+from tests import unittest
+from tests.server import render
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except Exception:
+ load_jinja2_templates = None
+
+
+class ConsentResourceTestCase(unittest.HomeserverTestCase):
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config.user_consent_version = "1"
+ config.public_baseurl = ""
+ config.form_secret = "123abc"
+
+ # Make some temporary templates...
+ temp_consent_path = self.mktemp()
+ os.mkdir(temp_consent_path)
+ os.mkdir(os.path.join(temp_consent_path, 'en'))
+ config.user_consent_template_dir = os.path.abspath(temp_consent_path)
+
+ with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
+ f.write("{{version}},{{has_consented}}")
+
+ with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f:
+ f.write("yay!")
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_accept_consent(self):
+ """
+ A user can use the consent form to accept the terms.
+ """
+ uri_builder = ConsentURIBuilder(self.hs.config)
+ resource = consent_resource.ConsentResource(self.hs)
+
+ # Register a user
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Fetch the consent page, to get the consent version
+ consent_uri = (
+ uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ + "&u=user"
+ )
+ request, channel = self.make_request(
+ "GET", consent_uri, access_token=access_token, shorthand=False
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Get the version from the body, and whether we've consented
+ version, consented = channel.result["body"].decode('ascii').split(",")
+ self.assertEqual(consented, "False")
+
+ # POST to the consent page, saying we've agreed
+ request, channel = self.make_request(
+ "POST",
+ consent_uri + "&v=" + version,
+ access_token=access_token,
+ shorthand=False,
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Fetch the consent page, to get the consent version -- it should have
+ # changed
+ request, channel = self.make_request(
+ "GET", consent_uri, access_token=access_token, shorthand=False
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Get the version from the body, and check that it's the version we
+ # agreed to, and that we've consented to it.
+ version, consented = channel.result["body"].decode('ascii').split(",")
+ self.assertEqual(consented, "True")
+ self.assertEqual(version, "1")
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 4c30c5f258..99b716f00a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,9 +15,11 @@
from mock import Mock
+from synapse.rest.client.v1 import admin, login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
@@ -65,3 +67,124 @@ class FilterTestCase(unittest.HomeserverTestCase):
["next_batch", "rooms", "account_data", "to_device", "device_lists"]
).issubset(set(channel.json_body.keys()))
)
+
+
+class SyncTypingTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def test_sync_backwards_typing(self):
+ """
+ If the typing serial goes backwards and the typing handler is then reset
+ (such as when the master restarts and sets the typing serial to 0), we
+ do not incorrectly return typing information that had a serial greater
+ than the now-reset serial.
+ """
+ typing_url = "/rooms/%s/typing/%s?access_token=%s"
+ sync_url = "/sync?timeout=3000000&access_token=%s&since=%s"
+
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ request, channel = self.make_request(
+ "GET", "/sync?access_token=%s" % (access_token,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Stop typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": false}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Should return immediately
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Reset typing serial back to 0, as if the master had.
+ typing = self.hs.get_typing_handler()
+ typing._latest_room_serial = 0
+
+ # Since it checks the state token, we need some state to update to
+ # invalidate the stream token.
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # This should time out! But it does not, because our stream token is
+ # ahead, and therefore it's saying the typing (that we've actually
+ # already seen) is new, since it's got a token above our new, now-reset
+ # stream token.
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Clear the typing information, so that it doesn't think everything is
+ # in the future.
+ typing._reset()
+
+ # Now it SHOULD fail as it never completes!
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.assertRaises(TimedOutException, self.render, request)
diff --git a/tests/server.py b/tests/server.py
index 819c854448..f63f33c94f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -21,6 +21,12 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+class TimedOutException(Exception):
+ """
+ A web query timed out.
+ """
+
+
@attr.s
class FakeChannel(object):
"""
@@ -98,10 +104,24 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
+def make_request(
+ method, path, content=b"", access_token=None, request=SynapseRequest, shorthand=True
+):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
+
+ Args:
+ method (bytes/unicode): The HTTP request method ("verb").
+ path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
+ escaped UTF-8 & spaces and such).
+ content (bytes or dict): The body of the request. JSON-encoded, if
+ a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
+
+ Returns:
+ A synapse.http.site.SynapseRequest.
"""
if not isinstance(method, bytes):
method = method.encode('ascii')
@@ -109,8 +129,8 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
if not isinstance(path, bytes):
path = path.encode('ascii')
- # Decorate it to be the full path
- if not path.startswith(b"/_matrix"):
+ # Decorate it to be the full path, if we're using shorthand
+ if shorthand and not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
@@ -153,7 +173,7 @@ def wait_until_result(clock, request, timeout=100):
x += 1
if x > timeout:
- raise Exception("Timed out waiting for request to finish.")
+ raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index efd85ebe6c..2e073a3afc 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -544,8 +544,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- self.assertTrue(state_d.called)
- state_before = state_d.result
+ state_before = self.successResultOf(state_d)
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -599,6 +598,103 @@ class LexicographicalTestCase(unittest.TestCase):
self.assertEqual(["o", "l", "n", "m", "p"], res)
+class SimpleParamStateTestCase(unittest.TestCase):
+ def setUp(self):
+ # We build up a simple DAG.
+
+ event_map = {}
+
+ create_event = FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ).to_event([], [])
+ event_map[create_event.event_id] = create_event
+
+ alice_member = FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event([create_event.event_id], [create_event.event_id])
+ event_map[alice_member.event_id] = alice_member
+
+ join_rules = FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ).to_event(
+ auth_events=[create_event.event_id, alice_member.event_id],
+ prev_events=[alice_member.event_id],
+ )
+ event_map[join_rules.event_id] = join_rules
+
+ # Bob and Charlie join at the same time, so there is a fork
+ bob_member = FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event(
+ auth_events=[create_event.event_id, join_rules.event_id],
+ prev_events=[join_rules.event_id],
+ )
+ event_map[bob_member.event_id] = bob_member
+
+ charlie_member = FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event(
+ auth_events=[create_event.event_id, join_rules.event_id],
+ prev_events=[join_rules.event_id],
+ )
+ event_map[charlie_member.event_id] = charlie_member
+
+ self.event_map = event_map
+ self.create_event = create_event
+ self.alice_member = alice_member
+ self.join_rules = join_rules
+ self.bob_member = bob_member
+ self.charlie_member = charlie_member
+
+ self.state_at_bob = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, bob_member]
+ }
+
+ self.state_at_charlie = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, charlie_member]
+ }
+
+ self.expected_combined_state = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, bob_member, charlie_member]
+ }
+
+ def test_event_map_none(self):
+ # Test that we correctly handle passing `None` as the event_map
+
+ state_d = resolve_events_with_store(
+ [self.state_at_bob, self.state_at_charlie],
+ event_map=None,
+ state_res_store=TestStateResolutionStore(self.event_map),
+ )
+
+ state = self.successResultOf(state_d)
+
+ self.assert_dict(self.expected_combined_state, state)
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -657,7 +753,7 @@ class TestStateResolutionStore(object):
result.add(event_id)
event = self.event_map[event_id]
- for aid, _ in event.auth_events:
+ for aid in event.auth_event_ids():
stack.append(aid)
return list(result)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 952a0a7b51..e1a34ccffd 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -112,7 +112,7 @@ class MessageAcceptTests(unittest.TestCase):
"origin_server_ts": 1,
"type": "m.room.message",
"origin": "test.serv",
- "content": "hewwo?",
+ "content": {"body": "hewwo?"},
"auth_events": [],
"prev_events": [("two:test.serv", {}), (most_recent, {})],
}
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
new file mode 100644
index 0000000000..0b71c6feb9
--- /dev/null
+++ b/tests/test_terms_auth.py
@@ -0,0 +1,124 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import six
+from mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import make_request
+
+
+class TermsTestCase(unittest.HomeserverTestCase):
+ servlets = [register_servlets]
+
+ def prepare(self, reactor, clock, hs):
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+ self.url = "/_matrix/client/r0/register"
+ self.registration_handler = Mock()
+ self.auth_handler = Mock()
+ self.device_handler = Mock()
+ hs.config.enable_registration = True
+ hs.config.registrations_require_3pid = []
+ hs.config.auto_join_rooms = []
+ hs.config.enable_registration_captcha = False
+
+ def test_ui_auth(self):
+ self.hs.config.user_consent_at_registration = True
+ self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
+ self.hs.config.public_baseurl = "https://example.org"
+ self.hs.config.user_consent_version = "1.0"
+
+ # Do a UI auth request
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ self.assertTrue(channel.json_body is not None)
+ self.assertIsInstance(channel.json_body["session"], six.text_type)
+
+ self.assertIsInstance(channel.json_body["flows"], list)
+ for flow in channel.json_body["flows"]:
+ self.assertIsInstance(flow["stages"], list)
+ self.assertTrue(len(flow["stages"]) > 0)
+ self.assertEquals(flow["stages"][-1], "m.login.terms")
+
+ expected_params = {
+ "m.login.terms": {
+ "policies": {
+ "privacy_policy": {
+ "en": {
+ "name": "My Cool Privacy Policy",
+ "url": "https://example.org/_matrix/consent?v=1.0",
+ },
+ "version": "1.0"
+ },
+ },
+ },
+ }
+ self.assertIsInstance(channel.json_body["params"], dict)
+ self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+
+ # We have to complete the dummy auth stage before completing the terms stage
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.dummy",
+ },
+ }
+ )
+
+ self.registration_handler.check_username = Mock(return_value=True)
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ self.render(request)
+
+ # We don't bother checking that the response is correct - we'll leave that to
+ # other tests. We just want to make sure we're on the right path.
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ # Finish the UI auth for terms
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.terms",
+ },
+ }
+ )
+ request, channel = make_request(b"POST", self.url, request_data)
+ self.render(request)
+
+ # We're interested in getting a response that looks like a successful
+ # registration, not so much that the details are exactly what we want.
+
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.assertTrue(channel.json_body is not None)
+ self.assertIsInstance(channel.json_body["user_id"], six.text_type)
+ self.assertIsInstance(channel.json_body["access_token"], six.text_type)
+ self.assertIsInstance(channel.json_body["device_id"], six.text_type)
diff --git a/tests/unittest.py b/tests/unittest.py
index 4d40bdb6a5..5e35c943d7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -258,7 +258,13 @@ class HomeserverTestCase(TestCase):
"""
def make_request(
- self, method, path, content=b"", access_token=None, request=SynapseRequest
+ self,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
):
"""
Create a SynapseRequest at the path using the method and containing the
@@ -270,6 +276,8 @@ class HomeserverTestCase(TestCase):
escaped UTF-8 & spaces and such).
content (bytes or dict): The body of the request. JSON-encoded, if
a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
Returns:
A synapse.http.site.SynapseRequest.
@@ -277,7 +285,7 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
- return make_request(method, path, content, access_token, request)
+ return make_request(method, path, content, access_token, request, shorthand)
def render(self, request):
"""
diff --git a/tests/utils.py b/tests/utils.py
index 565bb60d08..67ab916f30 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -123,6 +123,8 @@ def default_config(name):
config.user_directory_search_all_users = False
config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None
+ config.user_consent_at_registration = False
+ config.user_consent_policy_name = "Privacy Policy"
config.media_storage_providers = []
config.autocreate_auto_join_rooms = True
config.auto_join_rooms = []
|