summary refs log tree commit diff
path: root/tests/storage/test_client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_client_ips.py')
-rw-r--r--tests/storage/test_client_ips.py183
1 files changed, 164 insertions, 19 deletions
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py

index bd6fda6cb1..4577e9422b 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,34 +14,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -import tests.unittest -import tests.utils +from synapse.http.site import XForwardedForRequest +from synapse.rest.client.v1 import admin, login + +from tests import unittest -class ClientIpStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.clock = None # type: tests.utils.MockClock +class ClientIpStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - self.store = hs.get_datastore() - self.clock = hs.get_clock() + def prepare(self, hs, reactor, clock): + self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_insert_new_client_ip(self): - self.clock.now = 12345678 + self.reactor.advance(12345678) + user_id = "@user:id" - yield self.store.insert_client_ip( - user_id, - "access_token", "ip", "user_agent", "device_id", + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) r = result[(user_id, "device_id")] self.assertDictContainsSubset( @@ -52,5 +59,143 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): "user_agent": "user_agent", "last_seen": 12345678000, }, - r + r, + ) + + def test_disabled_monthly_active_user(self): + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) + ) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) + self.assertFalse(active) + + def test_adding_monthly_active_user_when_full(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + user_id = "@user:server" + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(lots_of_users) + ) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) + ) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) + self.assertFalse(active) + + def test_adding_monthly_active_user_when_space(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) + self.assertFalse(active) + + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) + ) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) + self.assertTrue(active) + + def test_updating_monthly_active_user_when_space(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + self.get_success( + self.store.register(user_id=user_id, token="123", password_hash=None) + ) + + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) + self.assertFalse(active) + + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) + ) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) + self.assertTrue(active) + + +class ClientIpAuthTestCase(unittest.HomeserverTestCase): + + servlets = [admin.register_servlets, login.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): + self.store = self.hs.get_datastore() + self.user_id = self.register_user("bob", "abc123", True) + + def test_request_with_xforwarded(self): + """ + The IP in X-Forwarded-For is entered into the client IPs table. + """ + self._runtest( + {b"X-Forwarded-For": b"127.9.0.1"}, + "127.9.0.1", + {"request": XForwardedForRequest}, + ) + + def test_request_from_getPeer(self): + """ + The IP returned by getPeer is entered into the client IPs table, if + there's no X-Forwarded-For header. + """ + self._runtest({}, "127.0.0.1", {}) + + def _runtest(self, headers, expected_ip, make_request_args): + device_id = "bleb" + + access_token = self.login("bob", "abc123", device_id=device_id) + + # Advance to a known time + self.reactor.advance(123456 - self.reactor.seconds()) + + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/admin/users/" + self.user_id, + access_token=access_token, + **make_request_args + ) + request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") + + # Add the optional headers + for h, v in headers.items(): + request.requestHeaders.addRawHeader(h, v) + self.render(request) + + # Advance so the save loop occurs + self.reactor.advance(100) + + result = self.get_success( + self.store.get_last_client_ip_by_device(self.user_id, device_id) + ) + r = result[(self.user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": self.user_id, + "device_id": device_id, + "ip": expected_ip, + "user_agent": "Mozzila pizza", + "last_seen": 123456100, + }, + r, )