summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py64
1 files changed, 64 insertions, 0 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3aa9ba3c43..a2dfa1ed05 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,6 +31,7 @@ from synapse.types import Requester
 
 from tests import unittest
 from tests.test_utils import simple_async_mock
+from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
 
@@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_failure(self.auth.get_user_by_req(request), AuthError)
 
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with the ID of a valid device for the user in question,
+        the requester instance tracks that device ID.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"DOPPELDEVICE"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a truth-y value
+        self.store.get_device = simple_async_mock({"hidden": False})
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.assertEquals(
+            requester.user.to_string(), masquerading_user_id.decode("utf8")
+        )
+        self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with an ID that is not a valid device ID for the user in question,
+        the request fails with the appropriate error code.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a falsey value
+        self.store.get_device = simple_async_mock(None)
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
+        self.assertEquals(failure.value.code, 400)
+        self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = simple_async_mock(
             TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")