summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/v1/test_admin.py33
-rw-r--r--tests/rest/client/v1/test_events.py100
-rw-r--r--tests/rest/client/v1/test_register.py89
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py78
-rw-r--r--tests/rest/client/v2_alpha/test_register.py86
-rw-r--r--tests/rest/media/v1/test_url_preview.py424
6 files changed, 500 insertions, 310 deletions
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py

index e38eb628a9..407bf0ac4c 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py
@@ -19,6 +19,7 @@ import json from mock import Mock +from synapse.api.constants import UserTypes from synapse.rest.client.v1.admin import register_servlets from tests import unittest @@ -147,7 +148,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac.update( + nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support" + ) want_mac = want_mac.hexdigest() body = json.dumps( @@ -156,6 +159,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "username": "bob", "password": "abc123", "admin": True, + "user_type": UserTypes.SUPPORT, "mac": want_mac, } ) @@ -174,7 +178,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac.update( + nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin" + ) want_mac = want_mac.hexdigest() body = json.dumps( @@ -202,8 +208,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): def test_missing_parts(self): """ Synapse will complain if you don't give nonce, username, password, and - mac. Admin is optional. Additional checks are done for length and - type. + mac. Admin and user_types are optional. Additional checks are done for length + and type. """ def nonce(): @@ -260,7 +266,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual('Invalid username', channel.json_body["error"]) # - # Username checks + # Password checks # # Must be present @@ -296,3 +302,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) + + # + # user_type check + # + + # Invalid user_type + body = json.dumps({ + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid"} + ) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid user type', channel.json_body["error"]) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 956f7fc4c4..483bebc832 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py
@@ -16,64 +16,49 @@ """ Tests REST events for /events paths.""" from mock import Mock, NonCallableMock -from six import PY3 -from twisted.internet import defer +from synapse.rest.client.v1 import admin, events, login, room -from ....utils import MockHttpResource, setup_test_homeserver -from .utils import RestTestCase +from tests import unittest -PATH_PREFIX = "/_matrix/client/api/v1" - -class EventStreamPermissionsTestCase(RestTestCase): +class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): """ Tests event streaming (GET /events). """ - if PY3: - skip = "Skip on Py3 until ported to use not V1 only register." + servlets = [ + events.register_servlets, + room.register_servlets, + admin.register_servlets, + login.register_servlets, + ] - @defer.inlineCallbacks - def setUp(self): - import synapse.rest.client.v1.events - import synapse.rest.client.v1_only.register - import synapse.rest.client.v1.room + def make_homeserver(self, reactor, clock): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + config = self.default_config() + config.enable_registration_captcha = False + config.enable_registration = True + config.auto_join_rooms = [] - hs = yield setup_test_homeserver( - self.addCleanup, - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + hs = self.setup_test_homeserver( + config=config, ratelimiter=NonCallableMock(spec_set=["send_message"]) ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) - hs.config.enable_registration_captcha = False - hs.config.enable_registration = True - hs.config.auto_join_rooms = [] hs.get_handlers().federation_handler = Mock() - synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource) - synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource) - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + return hs + + def prepare(self, hs, reactor, clock): # register an account - self.user_id = "sid1" - response = yield self.register(self.user_id) - self.token = response["access_token"] - self.user_id = response["user_id"] + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") # register a 2nd account - self.other_user = "other1" - response = yield self.register(self.other_user) - self.other_token = response["access_token"] - self.other_user = response["user_id"] + self.other_user = self.register_user("other2", "pass") + self.other_token = self.login(self.other_user, "pass") - def tearDown(self): - pass - - @defer.inlineCallbacks def test_stream_basic_permissions(self): # invalid token, expect 401 # note: this is in violation of the original v1 spec, which expected @@ -81,34 +66,37 @@ class EventStreamPermissionsTestCase(RestTestCase): # implementation is now part of the r0 implementation, the newer # behaviour is used instead to be consistent with the r0 spec. # see issue #2602 - (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s" % ("invalid" + self.token,) + request, channel = self.make_request( + "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.assertEquals(401, code, msg=str(response)) + self.render(request) + self.assertEquals(channel.code, 401, msg=channel.result) # valid token, expect content - (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s&timeout=0" % (self.token,) + request, channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(200, code, msg=str(response)) - self.assertTrue("chunk" in response) - self.assertTrue("start" in response) - self.assertTrue("end" in response) + self.render(request) + self.assertEquals(channel.code, 200, msg=channel.result) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("start" in channel.json_body) + self.assertTrue("end" in channel.json_body) - @defer.inlineCallbacks def test_stream_room_permissions(self): - room_id = yield self.create_room_as(self.other_user, tok=self.other_token) - yield self.send(room_id, tok=self.other_token) + room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) + self.helper.send(room_id, tok=self.other_token) # invited to room (expect no content for room) - yield self.invite( + self.helper.invite( room_id, src=self.other_user, targ=self.user_id, tok=self.other_token ) - (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s&timeout=0" % (self.token,) + # valid token, expect content + request, channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(200, code, msg=str(response)) + self.render(request) + self.assertEquals(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down self.assertEquals( @@ -116,7 +104,7 @@ class EventStreamPermissionsTestCase(RestTestCase): len( [ c - for c in response["chunk"] + for c in channel.json_body["chunk"] if not ( c.get("type") == "m.presence" and c["content"].get("user_id") == self.user_id @@ -126,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase): ) # joined room (expect all content for room) - yield self.join(room=room_id, user=self.user_id, tok=self.token) + self.helper.join(room=room_id, user=self.user_id, tok=self.token) # left to room (expect no content for room) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py deleted file mode 100644
index f973eff8cf..0000000000 --- a/tests/rest/client/v1/test_register.py +++ /dev/null
@@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from mock import Mock -from six import PY3 - -from twisted.test.proto_helpers import MemoryReactorClock - -from synapse.http.server import JsonResource -from synapse.rest.client.v1_only.register import register_servlets -from synapse.util import Clock - -from tests import unittest -from tests.server import make_request, render, setup_test_homeserver - - -class CreateUserServletTestCase(unittest.TestCase): - """ - Tests for CreateUserRestServlet. - """ - - if PY3: - skip = "Not ported to Python 3." - - def setUp(self): - self.registration_handler = Mock() - - self.appservice = Mock(sender="@as:test") - self.datastore = Mock( - get_app_service_by_token=Mock(return_value=self.appservice) - ) - - handlers = Mock(registration_handler=self.registration_handler) - self.reactor = MemoryReactorClock() - self.hs_clock = Clock(self.reactor) - - self.hs = self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor - ) - self.hs.get_datastore = Mock(return_value=self.datastore) - self.hs.get_handlers = Mock(return_value=handlers) - - def test_POST_createuser_with_valid_user(self): - - res = JsonResource(self.hs) - register_servlets(self.hs, res) - - request_data = json.dumps( - { - "localpart": "someone", - "displayname": "someone interesting", - "duration_seconds": 200, - } - ) - - url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service' - - user_id = "@someone:interesting" - token = "my token" - - self.registration_handler.get_or_create_user = Mock( - return_value=(user_id, token) - ) - - request, channel = make_request(self.reactor, b"POST", url, request_data) - render(request, res, self.reactor) - - self.assertEquals(channel.result["code"], b"200") - - det_data = { - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - } - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py new file mode 100644
index 0000000000..d3d43970fb --- /dev/null +++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS +from synapse.rest.client.v1 import admin, login +from synapse.rest.client.v2_alpha import capabilities + +from tests import unittest + + +class CapabilitiesTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + capabilities.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.url = b"/_matrix/client/r0/capabilities" + hs = self.setup_test_homeserver() + self.store = hs.get_datastore() + return hs + + def test_check_auth_required(self): + request, channel = self.make_request("GET", self.url) + self.render(request) + + self.assertEqual(channel.code, 401) + + def test_get_room_version_capabilities(self): + self.register_user("user", "pass") + access_token = self.login("user", "pass") + + request, channel = self.make_request("GET", self.url, access_token=access_token) + self.render(request) + capabilities = channel.json_body['capabilities'] + + self.assertEqual(channel.code, 200) + for room_version in capabilities['m.room_versions']['available'].keys(): + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) + self.assertEqual( + DEFAULT_ROOM_VERSION, capabilities['m.room_versions']['default'] + ) + + def test_get_change_password_capabilities(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.login(user, password) + + request, channel = self.make_request("GET", self.url, access_token=access_token) + self.render(request) + capabilities = channel.json_body['capabilities'] + + self.assertEqual(channel.code, 200) + + # Test case where password is handled outside of Synapse + self.assertTrue(capabilities['m.change_password']['enabled']) + self.get_success(self.store.user_set_password_hash(user, None)) + request, channel = self.make_request("GET", self.url, access_token=access_token) + self.render(request) + capabilities = channel.json_body['capabilities'] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities['m.change_password']['enabled']) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 753d5c3e80..906b348d3e 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,10 +1,7 @@ import json -from mock import Mock - -from twisted.python import failure - -from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.constants import LoginType +from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha.register import register_servlets from tests import unittest @@ -18,50 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/register" - self.appservice = None - self.auth = Mock( - get_appservice_by_req=Mock(side_effect=lambda x: self.appservice) - ) - - self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) - self.auth_handler = Mock( - check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None), - ) - self.registration_handler = Mock() - self.identity_handler = Mock() - self.login_handler = Mock() - self.device_handler = Mock() - self.device_handler.check_device_registered = Mock(return_value="FAKE") - - self.datastore = Mock(return_value=Mock()) - self.datastore.get_current_state_deltas = Mock(return_value=[]) - - # do the dance to hook it up to the hs global - self.handlers = Mock( - registration_handler=self.registration_handler, - identity_handler=self.identity_handler, - login_handler=self.login_handler, - ) self.hs = self.setup_test_homeserver() - self.hs.get_auth = Mock(return_value=self.auth) - self.hs.get_handlers = Mock(return_value=self.handlers) - self.hs.get_auth_handler = Mock(return_value=self.auth_handler) - self.hs.get_device_handler = Mock(return_value=self.device_handler) - self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False return self.hs def test_POST_appservice_registration_valid(self): - user_id = "@kermit:muppet" - token = "kermits_access_token" - self.appservice = {"id": "1234"} - self.registration_handler.appservice_register = Mock(return_value=user_id) - self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) - request_data = json.dumps({"username": "kermit"}) + user_id = "@as_user_kermit:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, self.hs.config.hostname, + id="1234", + namespaces={ + "users": [{"regex": r"@as_user.*", "exclusive": True}], + }, + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) request, channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -71,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { "user_id": user_id, - "access_token": token, "home_server": self.hs.hostname, } self.assertDictContainsSubset(det_data, channel.json_body) @@ -103,39 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self): - user_id = "@kermit:muppet" - token = "kermits_access_token" + user_id = "@kermit:test" device_id = "frogfone" - request_data = json.dumps( - {"username": "kermit", "password": "monkey", "device_id": device_id} - ) - self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) - self.device_handler.check_device_registered = Mock(return_value=device_id) - + params = { + "username": "kermit", + "password": "monkey", + "device_id": device_id, + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) det_data = { "user_id": user_id, - "access_token": token, "home_server": self.hs.hostname, "device_id": device_id, } self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) - self.auth_handler.get_login_tuple_for_user_id( - user_id, device_id=device_id, initial_device_display_name=None - ) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) - self.registration_handler.check_username = Mock(return_value=True) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - self.registration_handler.register = Mock(return_value=("@user:id", "t")) request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) @@ -144,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["error"], "Registration has been disabled") def test_POST_guest_registration(self): - user_id = "a@b" self.hs.config.macaroon_secret_key = "test" self.hs.config.allow_guest_access = True - self.registration_handler.register = Mock(return_value=(user_id, None)) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) det_data = { - "user_id": user_id, "home_server": self.hs.hostname, "device_id": "guest_device", } diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index c62f71b44a..650ce95a6f 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py
@@ -15,21 +15,55 @@ import os -from mock import Mock +import attr +from netaddr import IPSet -from twisted.internet.defer import Deferred +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.error import DNSLookupError +from twisted.python.failure import Failure +from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.web._newclient import ResponseDone from synapse.config.repository import MediaStorageProviderConfig -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.module_loader import load_module from tests import unittest +from tests.server import FakeTransport + + +@attr.s +class FakeResponse(object): + version = attr.ib() + code = attr.ib() + phrase = attr.ib() + headers = attr.ib() + body = attr.ib() + absoluteURI = attr.ib() + + @property + def request(self): + @attr.s + class FakeTransport(object): + absoluteURI = self.absoluteURI + + return FakeTransport() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" + end_content = ( + b'<html><head>' + b'<meta property="og:title" content="~matrix~" />' + b'<meta property="og:description" content="hi" />' + b'</head></html>' + ) def make_homeserver(self, reactor, clock): @@ -39,6 +73,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): config = self.default_config() config.url_preview_enabled = True config.max_spider_size = 9999999 + config.url_preview_ip_range_blacklist = IPSet( + ( + "192.168.1.1", + "1.0.0.0/8", + "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:800::/21", + ) + ) + config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) config.url_preview_url_blacklist = [] config.media_store_path = self.storage_path @@ -62,63 +105,50 @@ class URLPreviewTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.fetches = [] + self.media_repo = hs.get_media_repository_resource() + self.preview_url = self.media_repo.children[b'preview_url'] - def get_file(url, output_stream, max_size): - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ + self.lookups = {} - def write_to(r): - data, response = r - output_stream.write(data) - return response + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): - d = Deferred() - d.addCallback(write_to) - self.fetches.append((d, url)) - return make_deferred_yieldable(d) + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") - client = Mock() - client.get_file = get_file + for i in self.lookups[hostName]: + resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) + resolutionReceiver.resolutionComplete() + return resolutionReceiver - self.media_repo = hs.get_media_repository_resource() - preview_url = self.media_repo.children[b'preview_url'] - preview_url.client = client - self.preview_url = preview_url + self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # We've made one fetch - self.assertEqual(len(self.fetches), 1) - - end_content = ( - b'<html><head>' - b'<meta property="og:title" content="~matrix~" />' - b'<meta property="og:description" content="hi" />' - b'</head></html>' - ) - - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="utf8"'], - }, - "https://example.com", - 200, - ), - ) + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content ) self.pump() @@ -129,14 +159,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check the cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # Only one fetch, still, since we'll lean on the cache - self.assertEqual(len(self.fetches), 1) - # Check the cache response has the same content self.assertEqual(channel.code, 200) self.assertEqual( @@ -144,20 +171,17 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Clear the in-memory cache - self.assertIn("matrix.org", self.preview_url._cache) - self.preview_url._cache.pop("matrix.org") - self.assertNotIn("matrix.org", self.preview_url._cache) + self.assertIn("http://matrix.org", self.preview_url._cache) + self.preview_url._cache.pop("http://matrix.org") + self.assertNotIn("http://matrix.org", self.preview_url._cache) # Check the database cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # Only one fetch, still, since we'll lean on the cache - self.assertEqual(len(self.fetches), 1) - # Check the cache response has the same content self.assertEqual(channel.code, 200) self.assertEqual( @@ -165,78 +189,282 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def test_non_ascii_preview_httpequiv(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + + end_content = ( + b'<html><head>' + b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>' + b'<meta property="og:title" content="\xe4\xea\xe0" />' + b'<meta property="og:description" content="hi" />' + b'</head></html>' + ) request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # We've made one fetch - self.assertEqual(len(self.fetches), 1) + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + + def test_non_ascii_preview_content_type(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( b'<html><head>' - b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>' b'<meta property="og:title" content="\xe4\xea\xe0" />' b'<meta property="og:description" content="hi" />' b'</head></html>' ) - self.fetches[0][0].callback( + request, channel = self.make_request( + "GET", "url_preview?url=http://matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - # This charset=utf-8 should be ignored, because the - # document has a meta tag overriding it. - b"Content-Type": [b'text/html; charset="utf8"'], - }, - "https://example.com", - 200, - ), + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" ) + % (len(end_content),) + + end_content ) self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") - def test_non_ascii_preview_content_type(self): + def test_ipaddr(self): + """ + IP addresses can be previewed directly. + """ + self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() - # We've made one fetch - self.assertEqual(len(self.fetches), 1) + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) - end_content = ( - b'<html><head>' - b'<meta property="og:title" content="\xe4\xea\xe0" />' - b'<meta property="og:description" content="hi" />' - b'</head></html>' + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="windows-1251"'], - }, - "https://example.com", - 200, - ), - ) + def test_blacklisted_ip_specific(self): + """ + Blacklisted IP addresses, found via DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_range(self): + """ + Blacklisted IP ranges, IPs found over DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_specific_direct(self): + """ + Blacklisted IP addresses, accessed directly, are not spidered. + """ + request, channel = self.make_request( + "GET", "url_preview?url=http://192.168.1.1", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_range_direct(self): + """ + Blacklisted IP ranges, accessed directly, are not spidered. + """ + request, channel = self.make_request( + "GET", "url_preview?url=http://1.1.1.2", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ip_range_whitelisted_ip(self): + """ + Blacklisted but then subsequently whitelisted IP addresses can be + spidered. + """ + self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content ) self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + def test_blacklisted_ip_with_external_ip(self): + """ + If a hostname resolves a blacklisted IP, even if there's a + non-blacklisted one, it will be rejected. + """ + # Hardcode the URL resolving to the IP we want. + self.lookups[u"example.com"] = [ + (IPv4Address, "1.1.1.2"), + (IPv4Address, "8.8.8.8"), + ] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ipv6_specific(self): + """ + Blacklisted IP addresses, found via DNS, are not spidered. + """ + self.lookups["example.com"] = [ + (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + ] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ipv6_range(self): + """ + Blacklisted IP ranges, IPs found over DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + )