summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py64
-rw-r--r--tests/federation/test_federation_sender.py5
-rw-r--r--tests/rest/admin/test_room.py42
-rw-r--r--tests/rest/client/test_auth.py2
-rw-r--r--tests/unittest.py11
-rw-r--r--tests/util/test_glob_to_regex.py59
6 files changed, 156 insertions, 27 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3aa9ba3c43..3c9ca52922 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,6 +31,7 @@ from synapse.types import Requester
 
 from tests import unittest
 from tests.test_utils import simple_async_mock
+from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
 
@@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_failure(self.auth.get_user_by_req(request), AuthError)
 
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with the ID of a valid device for the user in question,
+        the requester instance tracks that device ID.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"DOPPELDEVICE"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a truth-y value
+        self.store.get_device_opt = simple_async_mock({"hidden": False})
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.assertEquals(
+            requester.user.to_string(), masquerading_user_id.decode("utf8")
+        )
+        self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with an ID that is not a valid device ID for the user in question,
+        the request fails with the appropriate error code.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a truth-y value
+        self.store.get_device_opt = simple_async_mock(None)
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
+        self.assertEquals(failure.value.code, 400)
+        self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = simple_async_mock(
             TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b457dad6d2..b2376e2db9 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
 
         # expect signing key update edu
-        self.assertEqual(len(self.edus), 1)
+        self.assertEqual(len(self.edus), 2)
+        self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
         self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
 
         # sign the devices
@@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     ) -> None:
         """Check that the txn has an EDU with a signing key update."""
         edus = txn["edus"]
-        self.assertEqual(len(edus), 1)
+        self.assertEqual(len(edus), 2)
 
     def generate_and_upload_device_signing_key(
         self, user_id: str, device_id: str
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index d3858e460d..22f9aa6234 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -83,7 +83,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
     def test_room_does_not_exist(self):
         """
-        Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+        Check that unknown rooms/server return 200
         """
         url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test"
 
@@ -94,8 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
     def test_room_is_not_valid(self):
         """
@@ -508,27 +507,36 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    @parameterized.expand(
-        [
-            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
-            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
-            ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
-        ]
-    )
-    def test_room_does_not_exist(self, method: str, url: str):
-        """
-        Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+    def test_room_does_not_exist(self):
         """
+        Check that unknown rooms/server return 200
 
+        This is important, as it allows incomplete vestiges of rooms to be cleared up
+        even if the create event/etc is missing.
+        """
+        room_id = "!unknown:test"
         channel = self.make_request(
-            method,
-            url % "!unknown:test",
+            "DELETE",
+            f"/_synapse/admin/v2/rooms/{room_id}",
             content={},
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        # get status
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v2/rooms/{room_id}/delete_status",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+        self.assertEqual("complete", channel.json_body["results"][0]["status"])
+        self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
 
     @parameterized.expand(
         [
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index aa8ad6d2e1..72bbc87b4a 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -703,7 +703,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         login_response1 = self.make_request(
             "POST",
             "/_matrix/client/r0/login",
-            {"org.matrix.msc2918.refresh_token": True, **body},
+            {"refresh_token": True, **body},
         )
         self.assertEqual(login_response1.code, 200, login_response1.result)
         self.assertApproximates(
diff --git a/tests/unittest.py b/tests/unittest.py
index eea0903f05..1431848367 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,16 +331,13 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """Block until all background database updates have completed.
-
-        Note that callers must ensure there's a store property created on the
-        testcase.
-        """
+        """Block until all background database updates have completed."""
+        store = self.hs.get_datastore()
         while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
+            store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(False), by=0.1
+                store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
     def make_homeserver(self, reactor, clock):
diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py
new file mode 100644
index 0000000000..220accb92b
--- /dev/null
+++ b/tests/util/test_glob_to_regex.py
@@ -0,0 +1,59 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util import glob_to_regex
+
+from tests.unittest import TestCase
+
+
+class GlobToRegexTestCase(TestCase):
+    def test_literal_match(self):
+        """patterns without wildcards should match"""
+        pat = glob_to_regex("foobaz")
+        self.assertTrue(
+            pat.match("FoobaZ"), "patterns should match and be case-insensitive"
+        )
+        self.assertFalse(
+            pat.match("x foobaz"), "pattern should not match at word boundaries"
+        )
+
+    def test_wildcard_match(self):
+        pat = glob_to_regex("f?o*baz")
+
+        self.assertTrue(
+            pat.match("FoobarbaZ"),
+            "* should match string and pattern should be case-insensitive",
+        )
+        self.assertTrue(pat.match("foobaz"), "* should match 0 characters")
+        self.assertFalse(pat.match("fooxaz"), "the character after * must match")
+        self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters")
+        self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters")
+
+    def test_multi_wildcard(self):
+        """patterns with multiple wildcards in a row should match"""
+        pat = glob_to_regex("**baz")
+        self.assertTrue(pat.match("agsgsbaz"), "** should match any string")
+        self.assertTrue(pat.match("baz"), "** should match the empty string")
+        self.assertEqual(pat.pattern, r"\A.{0,}baz\Z")
+
+        pat = glob_to_regex("*?baz")
+        self.assertTrue(pat.match("agsgsbaz"), "*? should match any string")
+        self.assertTrue(pat.match("abaz"), "*? should match a single char")
+        self.assertFalse(pat.match("baz"), "*? should not match the empty string")
+        self.assertEqual(pat.pattern, r"\A.{1,}baz\Z")
+
+        pat = glob_to_regex("a?*?*?baz")
+        self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars")
+        self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars")
+        self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars")
+        self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z")