diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/http/test_fedclient.py | 4 | ||||
-rw-r--r-- | tests/server.py | 8 | ||||
-rw-r--r-- | tests/storage/test_client_ips.py | 202 | ||||
-rw-r--r-- | tests/storage/test_state.py | 39 | ||||
-rw-r--r-- | tests/test_federation.py | 28 | ||||
-rw-r--r-- | tests/test_server.py | 74 | ||||
-rw-r--r-- | tests/unittest.py | 12 | ||||
-rw-r--r-- | tests/util/test_expiring_cache.py | 1 | ||||
-rw-r--r-- | tests/util/test_logcontext.py | 5 | ||||
-rw-r--r-- | tests/utils.py | 27 |
10 files changed, 339 insertions, 61 deletions
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 66c09f63b6..f3cb1423f0 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -54,7 +54,7 @@ class FederationClientTests(HomeserverTestCase): def test_client_never_connect(self): """ If the HTTP request is not connected and is timed out, it'll give a - ConnectingCancelledError. + ConnectingCancelledError or TimeoutError. """ d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) @@ -76,7 +76,7 @@ class FederationClientTests(HomeserverTestCase): self.reactor.advance(10.5) f = self.failureResultOf(d) - self.assertIsInstance(f.value, ConnectingCancelledError) + self.assertIsInstance(f.value, (ConnectingCancelledError, TimeoutError)) def test_client_connect_no_response(self): """ diff --git a/tests/server.py b/tests/server.py index ccea3baa55..7bee58dff1 100644 --- a/tests/server.py +++ b/tests/server.py @@ -98,7 +98,7 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b"", access_token=None): +def make_request(method, path, content=b"", access_token=None, request=SynapseRequest): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. @@ -120,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None): site = FakeSite() channel = FakeChannel() - req = SynapseRequest(site, channel) + req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) if access_token: req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) - req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") + if content: + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + req.requestReceived(method, path, b"1.1") return req, channel diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c9b02a062b..2ffbb9f14f 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +13,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import hashlib +import hmac +import json + from mock import Mock from twisted.internet import defer -import tests.unittest -import tests.utils +from synapse.http.site import XForwardedForRequest +from synapse.rest.client.v1 import admin, login + +from tests import unittest -class ClientIpStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.clock = None # type: tests.utils.MockClock +class ClientIpStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs - @defer.inlineCallbacks - def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) + def prepare(self, hs, reactor, clock): self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - @defer.inlineCallbacks def test_insert_new_client_ip(self): - self.clock.now = 12345678 + self.reactor.advance(12345678) + user_id = "@user:id" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) r = result[(user_id, "device_id")] self.assertDictContainsSubset( @@ -55,18 +66,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): r, ) - @defer.inlineCallbacks def test_disabled_monthly_active_user(self): self.hs.config.limit_usage_by_mau = False self.hs.config.max_mau_value = 50 user_id = "@user:server" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_full(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 @@ -76,38 +87,159 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) - @defer.inlineCallbacks def test_updating_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - yield self.store.register(user_id=user_id, token="123", password_hash=None) + self.get_success( + self.store.register(user_id=user_id, token="123", password_hash=None) + ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + + +class ClientIpAuthTestCase(unittest.HomeserverTestCase): + + servlets = [admin.register_servlets, login.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): + self.hs.config.registration_shared_secret = u"shared" + self.store = self.hs.get_datastore() + + # Create the user + request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + ) + self.render(request) + + self.assertEqual(channel.code, 200) + self.user_id = channel.json_body["user_id"] + + def test_request_with_xforwarded(self): + """ + The IP in X-Forwarded-For is entered into the client IPs table. + """ + self._runtest( + {b"X-Forwarded-For": b"127.9.0.1"}, + "127.9.0.1", + {"request": XForwardedForRequest}, + ) + + def test_request_from_getPeer(self): + """ + The IP returned by getPeer is entered into the client IPs table, if + there's no X-Forwarded-For header. + """ + self._runtest({}, "127.0.0.1", {}) + + def _runtest(self, headers, expected_ip, make_request_args): + device_id = "bleb" + + body = json.dumps( + { + "type": "m.login.password", + "user": "bob", + "password": "abc123", + "device_id": device_id, + } + ) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/login", body.encode('utf8'), **make_request_args + ) + self.render(request) + self.assertEqual(channel.code, 200) + access_token = channel.json_body["access_token"].encode('ascii') + + # Advance to a known time + self.reactor.advance(123456 - self.reactor.seconds()) + + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/admin/users/" + self.user_id, + body.encode('utf8'), + access_token=access_token, + **make_request_args + ) + request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") + + # Add the optional headers + for h, v in headers.items(): + request.requestHeaders.addRawHeader(h, v) + self.render(request) + + # Advance so the save loop occurs + self.reactor.advance(100) + + result = self.get_success( + self.store.get_last_client_ip_by_device(self.user_id, device_id) + ) + r = result[(self.user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": self.user_id, + "device_id": device_id, + "ip": expected_ip, + "user_agent": "Mozzila pizza", + "last_seen": 123456100, + }, + r, + ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b910965932..b9c5b39d59 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -75,6 +75,45 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(len(s1), len(s2)) @defer.inlineCallbacks + def test_get_state_groups_ids(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_map = list(state_group_map.values())[0] + self.assertDictEqual( + state_map, + { + (EventTypes.Create, ''): e1.event_id, + (EventTypes.Name, ''): e2.event_id, + }, + ) + + @defer.inlineCallbacks + def test_get_state_groups(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups( + self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_list = list(state_group_map.values())[0] + + self.assertEqual( + {ev.event_id for ev in state_list}, + {e1.event_id, e2.event_id}, + ) + + @defer.inlineCallbacks def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever diff --git a/tests/test_federation.py b/tests/test_federation.py index 2540604fcc..ff55c7a627 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -6,6 +6,7 @@ from twisted.internet.defer import maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.types import Requester, UserID from synapse.util import Clock +from synapse.util.logcontext import LoggingContext from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver @@ -117,9 +118,10 @@ class MessageAcceptTests(unittest.TestCase): } ) - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ) + with LoggingContext(request="lying_event"): + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) # Step the reactor, so the database fetches come back self.reactor.advance(1) @@ -209,11 +211,12 @@ class MessageAcceptTests(unittest.TestCase): } ) - d = self.handler.on_receive_pdu( - "test.serv", good_event, sent_to_us_directly=True - ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) + with LoggingContext(request="good_event"): + d = self.handler.on_receive_pdu( + "test.serv", good_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) bad_event = FrozenEvent( { @@ -230,10 +233,11 @@ class MessageAcceptTests(unittest.TestCase): } ) - d = self.handler.on_receive_pdu( - "test.serv", bad_event, sent_to_us_directly=True - ) - self.reactor.advance(1) + with LoggingContext(request="bad_event"): + d = self.handler.on_receive_pdu( + "test.serv", bad_event, sent_to_us_directly=True + ) + self.reactor.advance(1) extrem = maybeDeferred( self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id diff --git a/tests/test_server.py b/tests/test_server.py index ef74544e93..4045fdadc3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,14 +1,35 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging import re +from six import StringIO + from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.python.failure import Failure +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource +from synapse.http.site import SynapseSite, logger from synapse.util import Clock from tests import unittest -from tests.server import make_request, render, setup_test_homeserver +from tests.server import FakeTransport, make_request, render, setup_test_homeserver class JsonResourceTests(unittest.TestCase): @@ -121,3 +142,52 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + +class SiteTestCase(unittest.HomeserverTestCase): + def test_lose_connection(self): + """ + We log the URI correctly redacted when we lose the connection. + """ + + class HangingResource(Resource): + """ + A Resource that strategically hangs, as if it were processing an + answer. + """ + + def render(self, request): + return NOT_DONE_YET + + # Set up a logging handler that we can inspect afterwards + output = StringIO() + handler = logging.StreamHandler(output) + logger.addHandler(handler) + old_level = logger.level + logger.setLevel(10) + self.addCleanup(logger.setLevel, old_level) + self.addCleanup(logger.removeHandler, handler) + + # Make a resource and a Site, the resource will hang and allow us to + # time out the request while it's 'processing' + base_resource = Resource() + base_resource.putChild(b'', HangingResource()) + site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") + + server = site.buildProtocol(None) + client = AccumulatingProtocol() + client.makeConnection(FakeTransport(server, self.reactor)) + server.makeConnection(FakeTransport(client, self.reactor)) + + # Send a request with an access token that will get redacted + server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n") + self.pump() + + # Lose the connection + e = Failure(Exception("Failed123")) + server.connectionLost(e) + handler.flush() + + # Our access token is redacted and the failure reason is logged. + self.assertIn("/?access_token=<redacted>", output.getvalue()) + self.assertIn("Failed123", output.getvalue()) diff --git a/tests/unittest.py b/tests/unittest.py index a3d39920db..043710afaf 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -26,6 +26,7 @@ from twisted.internet.defer import Deferred from twisted.trial import unittest from synapse.http.server import JsonResource +from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.logcontext import LoggingContextFilter @@ -120,7 +121,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))(e.message + " for '.%s'" % key) + raise (type(e))(str(e) + " for '.%s'" % key) def assert_dict(self, required, actual): """Does a partial assert of a dict. @@ -219,7 +220,8 @@ class HomeserverTestCase(TestCase): Function to be overridden in subclasses. """ - raise NotImplementedError() + hs = self.setup_test_homeserver() + return hs def prepare(self, reactor, clock, homeserver): """ @@ -236,7 +238,9 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - def make_request(self, method, path, content=b""): + def make_request( + self, method, path, content=b"", access_token=None, request=SynapseRequest + ): """ Create a SynapseRequest at the path using the method and containing the given content. @@ -254,7 +258,7 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content) + return make_request(method, path, content, access_token, request) def render(self, request): """ diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 5cbada4eda..50bc7702d2 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -65,7 +65,6 @@ class ExpiringCacheTestCase(unittest.TestCase): def test_time_eviction(self): clock = MockClock() cache = ExpiringCache("test", clock, expiry_ms=1000) - cache.start() cache["key"] = 1 clock.advance_time(0.5) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 4633db77b3..8adaee3c8d 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -159,6 +159,11 @@ class LoggingContextTestCase(unittest.TestCase): self.assertEqual(r, "bum") self._check_test_key("one") + def test_nested_logging_context(self): + with LoggingContext(request="foo"): + nested_context = logcontext.nested_logging_context(suffix="bar") + self.assertEqual(nested_context.request, "foo-bar") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/utils.py b/tests/utils.py index 215226debf..aaed1149c3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,9 @@ import atexit import hashlib import os +import time import uuid +import warnings from inspect import getcallargs from mock import Mock, patch @@ -237,20 +239,41 @@ def setup_test_homeserver( else: # We need to do cleanup on PostgreSQL def cleanup(): + import psycopg2 + # Close all the db pools hs.get_db_pool().close() + dropped = False + # Drop the test database db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER ) db_conn.autocommit = True cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for x in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + cur.close() db_conn.close() + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + if not LEAVE_DB: # Register the cleanup hook cleanup_func(cleanup) |