diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_e2e_room_keys.py | 79 | ||||
-rw-r--r-- | tests/push/__init__.py | 0 | ||||
-rw-r--r-- | tests/push/test_email.py | 148 | ||||
-rw-r--r-- | tests/push/test_http.py | 159 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 4 | ||||
-rw-r--r-- | tests/rest/client/test_consent.py | 111 | ||||
-rw-r--r-- | tests/rest/client/v1/test_admin.py | 116 | ||||
-rw-r--r-- | tests/rest/client/v1/test_register.py | 10 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 106 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 22 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_filter.py | 95 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 52 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_sync.py | 123 | ||||
-rw-r--r-- | tests/rest/media/v1/test_url_preview.py | 164 | ||||
-rw-r--r-- | tests/server.py | 52 | ||||
-rw-r--r-- | tests/state/test_v2.py | 102 | ||||
-rw-r--r-- | tests/storage/test_end_to_end_keys.py | 15 | ||||
-rw-r--r-- | tests/test_federation.py | 2 | ||||
-rw-r--r-- | tests/test_mau.py | 37 | ||||
-rw-r--r-- | tests/test_server.py | 12 | ||||
-rw-r--r-- | tests/test_terms_auth.py | 123 | ||||
-rw-r--r-- | tests/unittest.py | 29 | ||||
-rw-r--r-- | tests/utils.py | 2 |
23 files changed, 1294 insertions, 269 deletions
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 9e08eac0a5..c8994f416e 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -169,8 +169,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, 404) @defer.inlineCallbacks - def test_get_missing_room_keys(self): - """Check that we get a 404 on querying missing room_keys + def test_get_missing_backup(self): + """Check that we get a 404 on querying missing backup """ res = None try: @@ -179,19 +179,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = e.code self.assertEqual(res, 404) - # check we also get a 404 even if the version is valid + @defer.inlineCallbacks + def test_get_missing_room_keys(self): + """Check we get an empty response from an empty backup + """ version = yield self.handler.create_version(self.local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", }) self.assertEqual(version, "1") - res = None - try: - yield self.handler.get_room_keys(self.local_user, version) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 404) + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertDictEqual(res, { + "rooms": {} + }) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest @@ -345,17 +346,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check for bulk-delete yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys(self.local_user, version) - res = None - try: - yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 404) + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + self.assertDictEqual(res, { + "rooms": {} + }) # check for bulk-delete per room yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -364,17 +363,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): version, room_id="!abc:matrix.org", ) - res = None - try: - yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 404) + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + self.assertDictEqual(res, { + "rooms": {} + }) # check for bulk-delete per session yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -384,14 +381,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): room_id="!abc:matrix.org", session_id="c0ff33", ) - res = None - try: - yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 404) + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + self.assertDictEqual(res, { + "rooms": {} + }) diff --git a/tests/push/__init__.py b/tests/push/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/push/__init__.py diff --git a/tests/push/test_email.py b/tests/push/test_email.py new file mode 100644 index 0000000000..50ee6910d1 --- /dev/null +++ b/tests/push/test_email.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pkg_resources + +from twisted.internet.defer import Deferred + +from synapse.rest.client.v1 import admin, login, room + +from tests.unittest import HomeserverTestCase + +try: + from synapse.push.mailer import load_jinja2_templates +except Exception: + load_jinja2_templates = None + + +class EmailPusherTests(HomeserverTestCase): + + skip = "No Jinja installed" if not load_jinja2_templates else None + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + config = self.default_config() + config.email_enable_notifs = True + config.start_pushers = True + + config.email_template_dir = os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ) + config.email_notif_template_html = "notif_mail.html" + config.email_notif_template_text = "notif_mail.txt" + config.email_smtp_host = "127.0.0.1" + config.email_smtp_port = 20 + config.require_transport_security = False + config.email_smtp_user = None + config.email_app_name = "Matrix" + config.email_notif_from = "test@example.com" + + hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + + return hs + + def test_sends_email(self): + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name="a@example.com", + pushkey="a@example.com", + lang=None, + data={}, + ) + ) + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Get the stream ordering before it gets sent + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + last_stream_ordering = pushers[0]["last_stream_ordering"] + + # Advance time a bit, so the pusher will register something has happened + self.pump(100) + + # It hasn't succeeded yet, so the stream ordering shouldn't have moved + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + + # One email was attempted to be sent + self.assertEqual(len(self.email_attempts), 1) + + # Make the email succeed + self.email_attempts[0][0].callback(True) + self.pump() + + # One email was attempted to be sent + self.assertEqual(len(self.email_attempts), 1) + + # The stream ordering has increased + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py new file mode 100644 index 0000000000..addc01ab7f --- /dev/null +++ b/tests/push/test_http.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet.defer import Deferred + +from synapse.rest.client.v1 import admin, login, room + +from tests.unittest import HomeserverTestCase + +try: + from synapse.push.mailer import load_jinja2_templates +except Exception: + load_jinja2_templates = None + + +class HTTPPusherTests(HomeserverTestCase): + + skip = "No Jinja installed" if not load_jinja2_templates else None + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + self.push_attempts = [] + + m = Mock() + + def post_json_get_json(url, body): + d = Deferred() + self.push_attempts.append((d, url, body)) + return d + + m.post_json_get_json = post_json_get_json + + config = self.default_config() + config.start_pushers = True + + hs = self.setup_test_homeserver(config=config, simple_http_client=m) + + return hs + + def test_sends_http(self): + """ + The HTTP pusher will send pushes for each message to a HTTP endpoint + when configured to do so. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Get the stream ordering before it gets sent + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + last_stream_ordering = pushers[0]["last_stream_ordering"] + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # It hasn't succeeded yet, so the stream ordering shouldn't have moved + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + + # One push was attempted to be sent -- it'll be the first message + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!" + ) + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # The stream ordering has increased + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + last_stream_ordering = pushers[0]["last_stream_ordering"] + + # Now it'll try and send the second push message, which will be the second one + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][2]["notification"]["content"]["body"], "There!" + ) + + # Make the second push succeed + self.push_attempts[1][0].callback({}) + self.pump() + + # The stream ordering has increased, again + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 41be5d5a1a..1688a741d1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -28,8 +28,8 @@ ROOM_ID = "!room:blue" def dict_equals(self, other): - me = encode_canonical_json(self._event_dict) - them = encode_canonical_json(other._event_dict) + me = encode_canonical_json(self.get_pdu_json()) + them = encode_canonical_json(other.get_pdu_json()) return me == them diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py new file mode 100644 index 0000000000..df3f1cde6e --- /dev/null +++ b/tests/rest/client/test_consent.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from synapse.api.urls import ConsentURIBuilder +from synapse.rest.client.v1 import admin, login, room +from synapse.rest.consent import consent_resource + +from tests import unittest +from tests.server import render + +try: + from synapse.push.mailer import load_jinja2_templates +except Exception: + load_jinja2_templates = None + + +class ConsentResourceTestCase(unittest.HomeserverTestCase): + skip = "No Jinja installed" if not load_jinja2_templates else None + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config.user_consent_version = "1" + config.public_baseurl = "" + config.form_secret = "123abc" + + # Make some temporary templates... + temp_consent_path = self.mktemp() + os.mkdir(temp_consent_path) + os.mkdir(os.path.join(temp_consent_path, 'en')) + config.user_consent_template_dir = os.path.abspath(temp_consent_path) + + with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f: + f.write("{{version}},{{has_consented}}") + + with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f: + f.write("yay!") + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_accept_consent(self): + """ + A user can use the consent form to accept the terms. + """ + uri_builder = ConsentURIBuilder(self.hs.config) + resource = consent_resource.ConsentResource(self.hs) + + # Register a user + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Fetch the consent page, to get the consent version + consent_uri = ( + uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") + + "&u=user" + ) + request, channel = self.make_request( + "GET", consent_uri, access_token=access_token, shorthand=False + ) + render(request, resource, self.reactor) + self.assertEqual(channel.code, 200) + + # Get the version from the body, and whether we've consented + version, consented = channel.result["body"].decode('ascii').split(",") + self.assertEqual(consented, "False") + + # POST to the consent page, saying we've agreed + request, channel = self.make_request( + "POST", + consent_uri + "&v=" + version, + access_token=access_token, + shorthand=False, + ) + render(request, resource, self.reactor) + self.assertEqual(channel.code, 200) + + # Fetch the consent page, to get the consent version -- it should have + # changed + request, channel = self.make_request( + "GET", consent_uri, access_token=access_token, shorthand=False + ) + render(request, resource, self.reactor) + self.assertEqual(channel.code, 200) + + # Get the version from the body, and check that it's the version we + # agreed to, and that we've consented to it. + version, consented = channel.result["body"].decode('ascii').split(",") + self.assertEqual(consented, "True") + self.assertEqual(version, "1") diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 1a553fa3f9..e38eb628a9 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -19,24 +19,17 @@ import json from mock import Mock -from synapse.http.server import JsonResource from synapse.rest.client.v1.admin import register_servlets -from synapse.util import Clock from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) -class UserRegisterTestCase(unittest.TestCase): - def setUp(self): +class UserRegisterTestCase(unittest.HomeserverTestCase): + + servlets = [register_servlets] + + def make_homeserver(self, reactor, clock): - self.clock = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.clock) self.url = "/_matrix/client/r0/admin/register" self.registration_handler = Mock() @@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase): self.secrets = Mock() - self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock - ) + self.hs = self.setup_test_homeserver() self.hs.config.registration_shared_secret = u"shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() - self.resource = JsonResource(self.hs) - register_servlets(self.hs, self.resource) + return self.hs def test_disabled(self): """ @@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase): """ self.hs.config.registration_shared_secret = None - request, channel = make_request("POST", self.url, b'{}') - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, b'{}') + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase): self.hs.get_secrets = Mock(return_value=secrets) - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) self.assertEqual(channel.json_body, {"nonce": "abcd"}) @@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase): Calling GET on the endpoint will return a randomised nonce, which will only last for SALT_TIMEOUT (60s). """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] # 59 seconds - self.clock.advance(59) + self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('username must be specified', channel.json_body["error"]) # 61 seconds - self.clock.advance(2) + self.reactor.advance(2) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('unrecognised nonce', channel.json_body["error"]) @@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase): """ Only the provided nonce can be used, as it's checked in the MAC. """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase): "mac": want_mac, } ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("HMAC incorrect", channel.json_body["error"]) @@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase): When the correct nonce is provided, and the right key is provided, the user is registered. """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase): "mac": want_mac, } ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase): """ A valid unrecognised nonce. """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase): "mac": want_mac, } ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('unrecognised nonce', channel.json_body["error"]) @@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase): """ def nonce(): - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) return channel.json_body["nonce"] # @@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase): # Must be present body = json.dumps({}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('nonce must be specified', channel.json_body["error"]) @@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('username must be specified', channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid username', channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid username', channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid username', channel.json_body["error"]) @@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('password must be specified', channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) @@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase): body = json.dumps( {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 6b7ff813d5..f973eff8cf 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase): ) handlers = Mock(registration_handler=self.registration_handler) - self.clock = MemoryReactorClock() - self.hs_clock = Clock(self.clock) + self.reactor = MemoryReactorClock() + self.hs_clock = Clock(self.reactor) self.hs = self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor ) self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) @@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase): return_value=(user_id, token) ) - request, channel = make_request(b"POST", url, request_data) - render(request, res, self.clock) + request, channel = make_request(self.reactor, b"POST", url, request_data) + render(request, res, self.reactor) self.assertEquals(channel.result["code"], b"200") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 359f7777ff..a824be9a62 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -23,7 +23,7 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer from synapse.api.constants import Membership -from synapse.rest.client.v1 import room +from synapse.rest.client.v1 import admin, login, room from tests import unittest @@ -799,3 +799,107 @@ class RoomMessageListTestCase(RoomBase): self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 530dc8ba6d..9c401bf300 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -169,7 +169,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - "POST", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8') ) render(request, self.resource, self.hs.get_reactor()) @@ -217,7 +217,9 @@ class RestHelper(object): data = {"membership": membership} - request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8') + ) render(request, self.resource, self.hs.get_reactor()) @@ -228,18 +230,6 @@ class RestHelper(object): self.auth_user_id = temp_id - @defer.inlineCallbacks - def register(self, user_id): - (code, response) = yield self.mock_resource.trigger( - "POST", - "/_matrix/client/r0/register", - json.dumps( - {"user": user_id, "password": "test", "type": "m.login.password"} - ), - ) - self.assertEquals(200, code) - defer.returnValue(response) - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -251,7 +241,9 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8') + ) render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 6a886ee3b8..f42a8efbf4 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,84 +13,47 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse.types from synapse.api.errors import Codes -from synapse.http.server import JsonResource from synapse.rest.client.v2_alpha import filter -from synapse.types import UserID -from synapse.util import Clock from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock as MemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) PATH_PREFIX = "/_matrix/client/v2_alpha" -class FilterTestCase(unittest.TestCase): +class FilterTestCase(unittest.HomeserverTestCase): - USER_ID = "@apple:test" + user_id = "@apple:test" + hijack_auth = True EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' - TO_REGISTER = [filter] + servlets = [filter.register_servlets] - def setUp(self): - self.clock = MemoryReactorClock() - self.hs_clock = Clock(self.clock) - - self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock - ) - - self.auth = self.hs.get_auth() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.USER_ID), - "token_id": 1, - "is_guest": False, - } - - def get_user_by_req(request, allow_guest=False, rights="access"): - return synapse.types.create_requester( - UserID.from_string(self.USER_ID), 1, False, None - ) - - self.auth.get_user_by_access_token = get_user_by_access_token - self.auth.get_user_by_req = get_user_by_req - - self.store = self.hs.get_datastore() - self.filtering = self.hs.get_filtering() - self.resource = JsonResource(self.hs) - - for r in self.TO_REGISTER: - r.register_servlets(self.hs, self.resource) + def prepare(self, reactor, clock, hs): + self.filtering = hs.get_filtering() + self.store = hs.get_datastore() def test_add_filter(self): - request, channel = make_request( + request, channel = self.make_request( "POST", - "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - render(request, self.resource, self.clock) + self.render(request) self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) - self.clock.advance(0) + self.pump() self.assertEquals(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): - request, channel = make_request( + request, channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, ) - render(request, self.resource, self.clock) + self.render(request) self.assertEqual(channel.result["code"], b"403") self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) @@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase): def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - request, channel = make_request( + request, channel = self.make_request( "POST", - "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - render(request, self.resource, self.clock) + self.render(request) self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") @@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase): filter_id = self.filtering.add_user_filter( user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) - self.clock.advance(1) + self.reactor.advance(1) filter_id = filter_id.result - request, channel = make_request( - "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) + request, channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) - render(request, self.resource, self.clock) + self.render(request) self.assertEqual(channel.result["code"], b"200") self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): - request, channel = make_request( - "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) + request, channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) - render(request, self.resource, self.clock) + self.render(request) self.assertEqual(channel.result["code"], b"400") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase): # Currently invalid params do not have an appropriate errcode # in errors.py def test_get_filter_invalid_id(self): - request, channel = make_request( - "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) + request, channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) - render(request, self.resource, self.clock) + self.render(request) self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error def test_get_filter_no_id(self): - request, channel = make_request( - "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) + request, channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) - render(request, self.resource, self.clock) + self.render(request) self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1c128e81f5..753d5c3e80 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -3,22 +3,19 @@ import json from mock import Mock from twisted.python import failure -from twisted.test.proto_helpers import MemoryReactorClock from synapse.api.errors import InteractiveAuthIncompleteError -from synapse.http.server import JsonResource from synapse.rest.client.v2_alpha.register import register_servlets -from synapse.util import Clock from tests import unittest -from tests.server import make_request, render, setup_test_homeserver -class RegisterRestServletTestCase(unittest.TestCase): - def setUp(self): +class RegisterRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [register_servlets] + + def make_homeserver(self, reactor, clock): - self.clock = MemoryReactorClock() - self.hs_clock = Clock(self.clock) self.url = b"/_matrix/client/r0/register" self.appservice = None @@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase): identity_handler=self.identity_handler, login_handler=self.login_handler, ) - self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock - ) + self.hs = self.setup_test_homeserver() self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) @@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] - self.resource = JsonResource(self.hs) - register_servlets(self.hs, self.resource) + return self.hs def test_POST_appservice_registration_valid(self): user_id = "@kermit:muppet" @@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase): self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) request_data = json.dumps({"username": "kermit"}) - request, channel = make_request( + request, channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - render(request, self.resource, self.clock) + self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { @@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase): def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists request_data = json.dumps({"username": "kermit"}) - request, channel = make_request( + request, channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - render(request, self.resource, self.clock) + self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) - request, channel = make_request(b"POST", self.url, request_data) - render(request, self.resource, self.clock) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = make_request(b"POST", self.url, request_data) - render(request, self.resource, self.clock) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid username") @@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase): self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) self.device_handler.check_device_registered = Mock(return_value=device_id) - request, channel = make_request(b"POST", self.url, request_data) - render(request, self.resource, self.clock) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) det_data = { "user_id": user_id, @@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase): self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=("@user:id", "t")) - request, channel = make_request(b"POST", self.url, request_data) - render(request, self.resource, self.clock) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") @@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.config.allow_guest_access = True self.registration_handler.register = Mock(return_value=(user_id, None)) - request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") - render(request, self.resource, self.clock) + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) det_data = { "user_id": user_id, @@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase): def test_POST_disabled_guest_registration(self): self.hs.config.allow_guest_access = False - request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") - render(request, self.resource, self.clock) + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 4c30c5f258..99b716f00a 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -15,9 +15,11 @@ from mock import Mock +from synapse.rest.client.v1 import admin, login, room from synapse.rest.client.v2_alpha import sync from tests import unittest +from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): @@ -65,3 +67,124 @@ class FilterTestCase(unittest.HomeserverTestCase): ["next_batch", "rooms", "account_data", "to_device", "device_lists"] ).issubset(set(channel.json_body.keys())) ) + + +class SyncTypingTests(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + user_id = True + hijack_auth = False + + def test_sync_backwards_typing(self): + """ + If the typing serial goes backwards and the typing handler is then reset + (such as when the master restarts and sets the typing serial to 0), we + do not incorrectly return typing information that had a serial greater + than the now-reset serial. + """ + typing_url = "/rooms/%s/typing/%s?access_token=%s" + sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Start typing. + request, channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.render(request) + self.assertEquals(200, channel.code) + + request, channel = self.make_request( + "GET", "/sync?access_token=%s" % (access_token,) + ) + self.render(request) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Stop typing. + request, channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": false}', + ) + self.render(request) + self.assertEquals(200, channel.code) + + # Start typing. + request, channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.render(request) + self.assertEquals(200, channel.code) + + # Should return immediately + request, channel = self.make_request( + "GET", sync_url % (access_token, next_batch) + ) + self.render(request) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Reset typing serial back to 0, as if the master had. + typing = self.hs.get_typing_handler() + typing._latest_room_serial = 0 + + # Since it checks the state token, we need some state to update to + # invalidate the stream token. + self.helper.send(room, body="There!", tok=other_access_token) + + request, channel = self.make_request( + "GET", sync_url % (access_token, next_batch) + ) + self.render(request) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # This should time out! But it does not, because our stream token is + # ahead, and therefore it's saying the typing (that we've actually + # already seen) is new, since it's got a token above our new, now-reset + # stream token. + request, channel = self.make_request( + "GET", sync_url % (access_token, next_batch) + ) + self.render(request) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Clear the typing information, so that it doesn't think everything is + # in the future. + typing._reset() + + # Now it SHOULD fail as it never completes! + request, channel = self.make_request( + "GET", sync_url % (access_token, next_batch) + ) + self.assertRaises(TimedOutException, self.render, request) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py new file mode 100644 index 0000000000..29579cf091 --- /dev/null +++ b/tests/rest/media/v1/test_url_preview.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from mock import Mock + +from twisted.internet.defer import Deferred + +from synapse.config.repository import MediaStorageProviderConfig +from synapse.util.module_loader import load_module + +from tests import unittest + + +class URLPreviewTests(unittest.HomeserverTestCase): + + hijack_auth = True + user_id = "@test:user" + + def make_homeserver(self, reactor, clock): + + self.storage_path = self.mktemp() + os.mkdir(self.storage_path) + + config = self.default_config() + config.url_preview_enabled = True + config.max_spider_size = 9999999 + config.url_preview_url_blacklist = [] + config.media_store_path = self.storage_path + + provider_config = { + "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + loaded = list(load_module(provider_config)) + [ + MediaStorageProviderConfig(False, False, False) + ] + + config.media_storage_providers = [loaded] + + hs = self.setup_test_homeserver(config=config) + + return hs + + def prepare(self, reactor, clock, hs): + + self.fetches = [] + + def get_file(url, output_stream, max_size): + """ + Returns tuple[int,dict,str,int] of file length, response headers, + absolute URI, and response code. + """ + + def write_to(r): + data, response = r + output_stream.write(data) + return response + + d = Deferred() + d.addCallback(write_to) + self.fetches.append((d, url)) + return d + + client = Mock() + client.get_file = get_file + + self.media_repo = hs.get_media_repository_resource() + preview_url = self.media_repo.children[b'preview_url'] + preview_url.client = client + self.preview_url = preview_url + + def test_cache_returns_correct_type(self): + + request, channel = self.make_request( + "GET", "url_preview?url=matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # We've made one fetch + self.assertEqual(len(self.fetches), 1) + + end_content = ( + b'<html><head>' + b'<meta property="og:title" content="~matrix~" />' + b'<meta property="og:description" content="hi" />' + b'</head></html>' + ) + + self.fetches[0][0].callback( + ( + end_content, + ( + len(end_content), + { + b"Content-Length": [b"%d" % (len(end_content))], + b"Content-Type": [b'text/html; charset="utf8"'], + }, + "https://example.com", + 200, + ), + ) + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + # Check the cache returns the correct response + request, channel = self.make_request( + "GET", "url_preview?url=matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Only one fetch, still, since we'll lean on the cache + self.assertEqual(len(self.fetches), 1) + + # Check the cache response has the same content + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + # Clear the in-memory cache + self.assertIn("matrix.org", self.preview_url._cache) + self.preview_url._cache.pop("matrix.org") + self.assertNotIn("matrix.org", self.preview_url._cache) + + # Check the database cache returns the correct response + request, channel = self.make_request( + "GET", "url_preview?url=matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Only one fetch, still, since we'll lean on the cache + self.assertEqual(len(self.fetches), 1) + + # Check the cache response has the same content + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) diff --git a/tests/server.py b/tests/server.py index 7bee58dff1..7919a1f124 100644 --- a/tests/server.py +++ b/tests/server.py @@ -21,6 +21,12 @@ from synapse.util import Clock from tests.utils import setup_test_homeserver as _sth +class TimedOutException(Exception): + """ + A web query timed out. + """ + + @attr.s class FakeChannel(object): """ @@ -28,6 +34,7 @@ class FakeChannel(object): wire). """ + _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @@ -50,6 +57,8 @@ class FakeChannel(object): self.result["headers"] = headers def write(self, content): + assert isinstance(content, bytes), "Should be bytes! " + repr(content) + if "body" not in self.result: self.result["body"] = b"" @@ -57,6 +66,15 @@ class FakeChannel(object): def registerProducer(self, producer, streaming): self._producer = producer + self.producerStreaming = streaming + + def _produce(): + if self._producer: + self._producer.resumeProducing() + self._reactor.callLater(0.1, _produce) + + if not streaming: + self._reactor.callLater(0.0, _produce) def unregisterProducer(self): if self._producer is None: @@ -98,10 +116,30 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b"", access_token=None, request=SynapseRequest): +def make_request( + reactor, + method, + path, + content=b"", + access_token=None, + request=SynapseRequest, + shorthand=True, +): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. + + Args: + method (bytes/unicode): The HTTP request method ("verb"). + path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. + escaped UTF-8 & spaces and such). + content (bytes or dict): The body of the request. JSON-encoded, if + a dict. + shorthand: Whether to try and be helpful and prefix the given URL + with the usual REST API path, if it doesn't contain it. + + Returns: + A synapse.http.site.SynapseRequest. """ if not isinstance(method, bytes): method = method.encode('ascii') @@ -109,8 +147,8 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe if not isinstance(path, bytes): path = path.encode('ascii') - # Decorate it to be the full path - if not path.startswith(b"/_matrix"): + # Decorate it to be the full path, if we're using shorthand + if shorthand and not path.startswith(b"/_matrix"): path = b"/_matrix/client/r0/" + path path = path.replace(b"//", b"/") @@ -118,14 +156,16 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe content = content.encode('utf8') site = FakeSite() - channel = FakeChannel() + channel = FakeChannel(reactor) req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) if access_token: - req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) + req.requestHeaders.addRawHeader( + b"Authorization", b"Bearer " + access_token.encode('ascii') + ) if content: req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") @@ -151,7 +191,7 @@ def wait_until_result(clock, request, timeout=100): x += 1 if x > timeout: - raise Exception("Timed out waiting for request to finish.") + raise TimedOutException("Timed out waiting for request to finish.") clock.advance(0.1) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index efd85ebe6c..2e073a3afc 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -544,8 +544,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - self.assertTrue(state_d.called) - state_before = state_d.result + state_before = self.successResultOf(state_d) state_after = dict(state_before) if fake_event.state_key is not None: @@ -599,6 +598,103 @@ class LexicographicalTestCase(unittest.TestCase): self.assertEqual(["o", "l", "n", "m", "p"], res) +class SimpleParamStateTestCase(unittest.TestCase): + def setUp(self): + # We build up a simple DAG. + + event_map = {} + + create_event = FakeEvent( + id="CREATE", + sender=ALICE, + type=EventTypes.Create, + state_key="", + content={"creator": ALICE}, + ).to_event([], []) + event_map[create_event.event_id] = create_event + + alice_member = FakeEvent( + id="IMA", + sender=ALICE, + type=EventTypes.Member, + state_key=ALICE, + content=MEMBERSHIP_CONTENT_JOIN, + ).to_event([create_event.event_id], [create_event.event_id]) + event_map[alice_member.event_id] = alice_member + + join_rules = FakeEvent( + id="IJR", + sender=ALICE, + type=EventTypes.JoinRules, + state_key="", + content={"join_rule": JoinRules.PUBLIC}, + ).to_event( + auth_events=[create_event.event_id, alice_member.event_id], + prev_events=[alice_member.event_id], + ) + event_map[join_rules.event_id] = join_rules + + # Bob and Charlie join at the same time, so there is a fork + bob_member = FakeEvent( + id="IMB", + sender=BOB, + type=EventTypes.Member, + state_key=BOB, + content=MEMBERSHIP_CONTENT_JOIN, + ).to_event( + auth_events=[create_event.event_id, join_rules.event_id], + prev_events=[join_rules.event_id], + ) + event_map[bob_member.event_id] = bob_member + + charlie_member = FakeEvent( + id="IMC", + sender=CHARLIE, + type=EventTypes.Member, + state_key=CHARLIE, + content=MEMBERSHIP_CONTENT_JOIN, + ).to_event( + auth_events=[create_event.event_id, join_rules.event_id], + prev_events=[join_rules.event_id], + ) + event_map[charlie_member.event_id] = charlie_member + + self.event_map = event_map + self.create_event = create_event + self.alice_member = alice_member + self.join_rules = join_rules + self.bob_member = bob_member + self.charlie_member = charlie_member + + self.state_at_bob = { + (e.type, e.state_key): e.event_id + for e in [create_event, alice_member, join_rules, bob_member] + } + + self.state_at_charlie = { + (e.type, e.state_key): e.event_id + for e in [create_event, alice_member, join_rules, charlie_member] + } + + self.expected_combined_state = { + (e.type, e.state_key): e.event_id + for e in [create_event, alice_member, join_rules, bob_member, charlie_member] + } + + def test_event_map_none(self): + # Test that we correctly handle passing `None` as the event_map + + state_d = resolve_events_with_store( + [self.state_at_bob, self.state_at_charlie], + event_map=None, + state_res_store=TestStateResolutionStore(self.event_map), + ) + + state = self.successResultOf(state_d) + + self.assert_dict(self.expected_combined_state, state) + + def pairwise(iterable): "s -> (s0,s1), (s1,s2), (s2, s3), ..." a, b = itertools.tee(iterable) @@ -657,7 +753,7 @@ class TestStateResolutionStore(object): result.add(event_id) event = self.event_map[event_id] - for aid, _ in event.auth_events: + for aid in event.auth_event_ids(): stack.append(aid) return list(result) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 8f0aaece40..b83f7336d3 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -45,6 +45,21 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) @defer.inlineCallbacks + def test_reupload_key(self): + now = 1470174257070 + json = {"key": "value"} + + yield self.store.store_device("user", "device", None) + + changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + self.assertTrue(changed) + + # If we try to upload the same key then we should be told nothing + # changed + changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + self.assertFalse(changed) + + @defer.inlineCallbacks def test_get_key_with_device_name(self): now = 1470174257070 json = {"key": "value"} diff --git a/tests/test_federation.py b/tests/test_federation.py index 952a0a7b51..e1a34ccffd 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -112,7 +112,7 @@ class MessageAcceptTests(unittest.TestCase): "origin_server_ts": 1, "type": "m.room.message", "origin": "test.serv", - "content": "hewwo?", + "content": {"body": "hewwo?"}, "auth_events": [], "prev_events": [("two:test.serv", {}), (most_recent, {})], } diff --git a/tests/test_mau.py b/tests/test_mau.py index bdbacb8448..0afdeb0818 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -21,30 +21,20 @@ from mock import Mock, NonCallableMock from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError -from synapse.http.server import JsonResource from synapse.rest.client.v2_alpha import register, sync -from synapse.util import Clock from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) -class TestMauLimit(unittest.TestCase): - def setUp(self): - self.reactor = ThreadedMemoryReactorClock() - self.clock = Clock(self.reactor) +class TestMauLimit(unittest.HomeserverTestCase): - self.hs = setup_test_homeserver( - self.addCleanup, + servlets = [register.register_servlets, sync.register_servlets] + + def make_homeserver(self, reactor, clock): + + self.hs = self.setup_test_homeserver( "red", http_client=None, - clock=self.clock, - reactor=self.reactor, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) @@ -63,10 +53,7 @@ class TestMauLimit(unittest.TestCase): self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_room_name = "Test Server Notice Room" - - self.resource = JsonResource(self.hs) - register.register_servlets(self.hs, self.resource) - sync.register_servlets(self.hs, self.resource) + return self.hs def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated @@ -193,8 +180,8 @@ class TestMauLimit(unittest.TestCase): } ) - request, channel = make_request("POST", "/register", request_data) - render(request, self.resource, self.reactor) + request, channel = self.make_request("POST", "/register", request_data) + self.render(request) if channel.code != 200: raise HttpResponseException( @@ -206,10 +193,10 @@ class TestMauLimit(unittest.TestCase): return access_token def do_sync_for_user(self, token): - request, channel = make_request( - "GET", "/sync", access_token=token.encode('ascii') + request, channel = self.make_request( + "GET", "/sync", access_token=token ) - render(request, self.resource, self.reactor) + self.render(request) if channel.code != 200: raise HttpResponseException( diff --git a/tests/test_server.py b/tests/test_server.py index 4045fdadc3..f0e6291b7e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -57,7 +57,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback ) - request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") + request, channel = make_request( + self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" + ) render(request, res, self.reactor) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) @@ -75,7 +77,7 @@ class JsonResourceTests(unittest.TestCase): res = JsonResource(self.homeserver) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/_matrix/foo") + request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'500') @@ -98,7 +100,7 @@ class JsonResourceTests(unittest.TestCase): res = JsonResource(self.homeserver) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/_matrix/foo") + request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'500') @@ -115,7 +117,7 @@ class JsonResourceTests(unittest.TestCase): res = JsonResource(self.homeserver) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/_matrix/foo") + request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'403') @@ -136,7 +138,7 @@ class JsonResourceTests(unittest.TestCase): res = JsonResource(self.homeserver) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/_matrix/foobar") + request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'400') diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py new file mode 100644 index 0000000000..9ecc3ef14f --- /dev/null +++ b/tests/test_terms_auth.py @@ -0,0 +1,123 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import six +from mock import Mock + +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.util import Clock + +from tests import unittest + + +class TermsTestCase(unittest.HomeserverTestCase): + servlets = [register_servlets] + + def prepare(self, reactor, clock, hs): + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = "/_matrix/client/r0/register" + self.registration_handler = Mock() + self.auth_handler = Mock() + self.device_handler = Mock() + hs.config.enable_registration = True + hs.config.registrations_require_3pid = [] + hs.config.auto_join_rooms = [] + hs.config.enable_registration_captcha = False + + def test_ui_auth(self): + self.hs.config.user_consent_at_registration = True + self.hs.config.user_consent_policy_name = "My Cool Privacy Policy" + self.hs.config.public_baseurl = "https://example.org" + self.hs.config.user_consent_version = "1.0" + + # Do a UI auth request + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + self.assertTrue(channel.json_body is not None) + self.assertIsInstance(channel.json_body["session"], six.text_type) + + self.assertIsInstance(channel.json_body["flows"], list) + for flow in channel.json_body["flows"]: + self.assertIsInstance(flow["stages"], list) + self.assertTrue(len(flow["stages"]) > 0) + self.assertEquals(flow["stages"][-1], "m.login.terms") + + expected_params = { + "m.login.terms": { + "policies": { + "privacy_policy": { + "en": { + "name": "My Cool Privacy Policy", + "url": "https://example.org/_matrix/consent?v=1.0", + }, + "version": "1.0" + }, + }, + }, + } + self.assertIsInstance(channel.json_body["params"], dict) + self.assertDictContainsSubset(channel.json_body["params"], expected_params) + + # We have to complete the dummy auth stage before completing the terms stage + request_data = json.dumps( + { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.dummy", + }, + } + ) + + self.registration_handler.check_username = Mock(return_value=True) + + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) + + # We don't bother checking that the response is correct - we'll leave that to + # other tests. We just want to make sure we're on the right path. + self.assertEquals(channel.result["code"], b"401", channel.result) + + # Finish the UI auth for terms + request_data = json.dumps( + { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.terms", + }, + } + ) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) + + # We're interested in getting a response that looks like a successful + # registration, not so much that the details are exactly what we want. + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertTrue(channel.json_body is not None) + self.assertIsInstance(channel.json_body["user_id"], six.text_type) + self.assertIsInstance(channel.json_body["access_token"], six.text_type) + self.assertIsInstance(channel.json_body["device_id"], six.text_type) diff --git a/tests/unittest.py b/tests/unittest.py index a59291cc60..a9ce57da9a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -146,6 +146,13 @@ def DEBUG(target): return target +def INFO(target): + """A decorator to set the .loglevel attribute to logging.INFO. + Can apply to either a TestCase or an individual test method.""" + target.loglevel = logging.INFO + return target + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -182,11 +189,11 @@ class HomeserverTestCase(TestCase): for servlet in self.servlets: servlet(self.hs, self.resource) - if hasattr(self, "user_id"): - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.v1.utils import RestHelper - self.helper = RestHelper(self.hs, self.resource, self.user_id) + self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) + if hasattr(self, "user_id"): if self.hijack_auth: def get_user_by_access_token(token=None, allow_guest=False): @@ -251,7 +258,13 @@ class HomeserverTestCase(TestCase): """ def make_request( - self, method, path, content=b"", access_token=None, request=SynapseRequest + self, + method, + path, + content=b"", + access_token=None, + request=SynapseRequest, + shorthand=True, ): """ Create a SynapseRequest at the path using the method and containing the @@ -263,6 +276,8 @@ class HomeserverTestCase(TestCase): escaped UTF-8 & spaces and such). content (bytes or dict): The body of the request. JSON-encoded, if a dict. + shorthand: Whether to try and be helpful and prefix the given URL + with the usual REST API path, if it doesn't contain it. Returns: A synapse.http.site.SynapseRequest. @@ -270,7 +285,9 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content, access_token, request) + return make_request( + self.reactor, method, path, content, access_token, request, shorthand + ) def render(self, request): """ @@ -373,5 +390,5 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 200) - access_token = channel.json_body["access_token"].encode('ascii') + access_token = channel.json_body["access_token"] return access_token diff --git a/tests/utils.py b/tests/utils.py index 565bb60d08..67ab916f30 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -123,6 +123,8 @@ def default_config(name): config.user_directory_search_all_users = False config.user_consent_server_notice_content = None config.block_events_without_consent_error = None + config.user_consent_at_registration = False + config.user_consent_policy_name = "Privacy Policy" config.media_storage_providers = [] config.autocreate_auto_join_rooms = True config.auto_join_rooms = [] |