summary refs log tree commit diff
path: root/tests/storage/test_client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_client_ips.py')
-rw-r--r--tests/storage/test_client_ips.py202
1 files changed, 167 insertions, 35 deletions
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index c9b02a062b..2ffbb9f14f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,35 +13,45 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import hashlib
+import hmac
+import json
+
 from mock import Mock
 
 from twisted.internet import defer
 
-import tests.unittest
-import tests.utils
+from synapse.http.site import XForwardedForRequest
+from synapse.rest.client.v1 import admin, login
+
+from tests import unittest
 
 
-class ClientIpStoreTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
-        self.clock = None  # type: tests.utils.MockClock
+class ClientIpStoreTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        return hs
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+    def prepare(self, hs, reactor, clock):
         self.store = self.hs.get_datastore()
-        self.clock = self.hs.get_clock()
 
-    @defer.inlineCallbacks
     def test_insert_new_client_ip(self):
-        self.clock.now = 12345678
+        self.reactor.advance(12345678)
+
         user_id = "@user:id"
-        yield self.store.insert_client_ip(
-            user_id, "access_token", "ip", "user_agent", "device_id"
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", "device_id"
+            )
         )
 
-        result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
+        # Trigger the storage loop
+        self.reactor.advance(10)
+
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, "device_id")
+        )
 
         r = result[(user_id, "device_id")]
         self.assertDictContainsSubset(
@@ -55,18 +66,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
             r,
         )
 
-    @defer.inlineCallbacks
     def test_disabled_monthly_active_user(self):
         self.hs.config.limit_usage_by_mau = False
         self.hs.config.max_mau_value = 50
         user_id = "@user:server"
-        yield self.store.insert_client_ip(
-            user_id, "access_token", "ip", "user_agent", "device_id"
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", "device_id"
+            )
         )
-        active = yield self.store.user_last_seen_monthly_active(user_id)
+        active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
 
-    @defer.inlineCallbacks
     def test_adding_monthly_active_user_when_full(self):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 50
@@ -76,38 +87,159 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
         self.store.get_monthly_active_count = Mock(
             return_value=defer.succeed(lots_of_users)
         )
-        yield self.store.insert_client_ip(
-            user_id, "access_token", "ip", "user_agent", "device_id"
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", "device_id"
+            )
         )
-        active = yield self.store.user_last_seen_monthly_active(user_id)
+        active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
 
-    @defer.inlineCallbacks
     def test_adding_monthly_active_user_when_space(self):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 50
         user_id = "@user:server"
-        active = yield self.store.user_last_seen_monthly_active(user_id)
+        active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
 
-        yield self.store.insert_client_ip(
-            user_id, "access_token", "ip", "user_agent", "device_id"
+        # Trigger the saving loop
+        self.reactor.advance(10)
+
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", "device_id"
+            )
         )
-        active = yield self.store.user_last_seen_monthly_active(user_id)
+        active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertTrue(active)
 
-    @defer.inlineCallbacks
     def test_updating_monthly_active_user_when_space(self):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 50
         user_id = "@user:server"
-        yield self.store.register(user_id=user_id, token="123", password_hash=None)
+        self.get_success(
+            self.store.register(user_id=user_id, token="123", password_hash=None)
+        )
 
-        active = yield self.store.user_last_seen_monthly_active(user_id)
+        active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
 
-        yield self.store.insert_client_ip(
-            user_id, "access_token", "ip", "user_agent", "device_id"
+        # Trigger the saving loop
+        self.reactor.advance(10)
+
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", "device_id"
+            )
         )
-        active = yield self.store.user_last_seen_monthly_active(user_id)
+        active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertTrue(active)
+
+
+class ClientIpAuthTestCase(unittest.HomeserverTestCase):
+
+    servlets = [admin.register_servlets, login.register_servlets]
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        return hs
+
+    def prepare(self, hs, reactor, clock):
+        self.hs.config.registration_shared_secret = u"shared"
+        self.store = self.hs.get_datastore()
+
+        # Create the user
+        request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
+        self.render(request)
+        nonce = channel.json_body["nonce"]
+
+        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+        want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+        want_mac = want_mac.hexdigest()
+
+        body = json.dumps(
+            {
+                "nonce": nonce,
+                "username": "bob",
+                "password": "abc123",
+                "admin": True,
+                "mac": want_mac,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200)
+        self.user_id = channel.json_body["user_id"]
+
+    def test_request_with_xforwarded(self):
+        """
+        The IP in X-Forwarded-For is entered into the client IPs table.
+        """
+        self._runtest(
+            {b"X-Forwarded-For": b"127.9.0.1"},
+            "127.9.0.1",
+            {"request": XForwardedForRequest},
+        )
+
+    def test_request_from_getPeer(self):
+        """
+        The IP returned by getPeer is entered into the client IPs table, if
+        there's no X-Forwarded-For header.
+        """
+        self._runtest({}, "127.0.0.1", {})
+
+    def _runtest(self, headers, expected_ip, make_request_args):
+        device_id = "bleb"
+
+        body = json.dumps(
+            {
+                "type": "m.login.password",
+                "user": "bob",
+                "password": "abc123",
+                "device_id": device_id,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/login", body.encode('utf8'), **make_request_args
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        access_token = channel.json_body["access_token"].encode('ascii')
+
+        # Advance to a known time
+        self.reactor.advance(123456 - self.reactor.seconds())
+
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/admin/users/" + self.user_id,
+            body.encode('utf8'),
+            access_token=access_token,
+            **make_request_args
+        )
+        request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
+
+        # Add the optional headers
+        for h, v in headers.items():
+            request.requestHeaders.addRawHeader(h, v)
+        self.render(request)
+
+        # Advance so the save loop occurs
+        self.reactor.advance(100)
+
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(self.user_id, device_id)
+        )
+        r = result[(self.user_id, device_id)]
+        self.assertDictContainsSubset(
+            {
+                "user_id": self.user_id,
+                "device_id": device_id,
+                "ip": expected_ip,
+                "user_agent": "Mozzila pizza",
+                "last_seen": 123456100,
+            },
+            r,
+        )