diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3aa9ba3c43..3c9ca52922 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,6 +31,7 @@ from synapse.types import Requester
from tests import unittest
from tests.test_utils import simple_async_mock
+from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), AuthError)
+ @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+ def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+ """
+ Tests that when an application service passes the device_id URL parameter
+ with the ID of a valid device for the user in question,
+ the requester instance tracks that device ID.
+ """
+ masquerading_user_id = b"@doppelganger:matrix.org"
+ masquerading_device_id = b"DOPPELDEVICE"
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+ )
+ app_service.is_interested_in_user = Mock(return_value=True)
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+ self.store.get_user_by_access_token = simple_async_mock(None)
+ # This also needs to just return a truth-y value
+ self.store.get_device_opt = simple_async_mock({"hidden": False})
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.args[b"user_id"] = [masquerading_user_id]
+ request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.assertEquals(
+ requester.user.to_string(), masquerading_user_id.decode("utf8")
+ )
+ self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+
+ @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+ """
+ Tests that when an application service passes the device_id URL parameter
+ with an ID that is not a valid device ID for the user in question,
+ the request fails with the appropriate error code.
+ """
+ masquerading_user_id = b"@doppelganger:matrix.org"
+ masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+ )
+ app_service.is_interested_in_user = Mock(return_value=True)
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+ self.store.get_user_by_access_token = simple_async_mock(None)
+ # This also needs to just return a truth-y value
+ self.store.get_device_opt = simple_async_mock(None)
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.args[b"user_id"] = [masquerading_user_id]
+ request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
+ self.assertEquals(failure.value.code, 400)
+ self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b457dad6d2..b2376e2db9 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
# expect signing key update edu
- self.assertEqual(len(self.edus), 1)
+ self.assertEqual(len(self.edus), 2)
+ self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
# sign the devices
@@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
) -> None:
"""Check that the txn has an EDU with a signing key update."""
edus = txn["edus"]
- self.assertEqual(len(edus), 1)
+ self.assertEqual(len(edus), 2)
def generate_and_upload_device_signing_key(
self, user_id: str, device_id: str
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index d3858e460d..22f9aa6234 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -83,7 +83,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_room_does_not_exist(self):
"""
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ Check that unknown rooms/server return 200
"""
url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test"
@@ -94,8 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_room_is_not_valid(self):
"""
@@ -508,27 +507,36 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- @parameterized.expand(
- [
- ("DELETE", "/_synapse/admin/v2/rooms/%s"),
- ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
- ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
- ]
- )
- def test_room_does_not_exist(self, method: str, url: str):
- """
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ def test_room_does_not_exist(self):
"""
+ Check that unknown rooms/server return 200
+ This is important, as it allows incomplete vestiges of rooms to be cleared up
+ even if the create event/etc is missing.
+ """
+ room_id = "!unknown:test"
channel = self.make_request(
- method,
- url % "!unknown:test",
+ "DELETE",
+ f"/_synapse/admin/v2/rooms/{room_id}",
content={},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/rooms/{room_id}/delete_status",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
@parameterized.expand(
[
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index aa8ad6d2e1..72bbc87b4a 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -703,7 +703,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
login_response1 = self.make_request(
"POST",
"/_matrix/client/r0/login",
- {"org.matrix.msc2918.refresh_token": True, **body},
+ {"refresh_token": True, **body},
)
self.assertEqual(login_response1.code, 200, login_response1.result)
self.assertApproximates(
diff --git a/tests/unittest.py b/tests/unittest.py
index eea0903f05..1431848367 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,16 +331,13 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """Block until all background database updates have completed.
-
- Note that callers must ensure there's a store property created on the
- testcase.
- """
+ """Block until all background database updates have completed."""
+ store = self.hs.get_datastore()
while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
+ store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db_pool.updates.do_next_background_update(False), by=0.1
+ store.db_pool.updates.do_next_background_update(False), by=0.1
)
def make_homeserver(self, reactor, clock):
diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py
new file mode 100644
index 0000000000..220accb92b
--- /dev/null
+++ b/tests/util/test_glob_to_regex.py
@@ -0,0 +1,59 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util import glob_to_regex
+
+from tests.unittest import TestCase
+
+
+class GlobToRegexTestCase(TestCase):
+ def test_literal_match(self):
+ """patterns without wildcards should match"""
+ pat = glob_to_regex("foobaz")
+ self.assertTrue(
+ pat.match("FoobaZ"), "patterns should match and be case-insensitive"
+ )
+ self.assertFalse(
+ pat.match("x foobaz"), "pattern should not match at word boundaries"
+ )
+
+ def test_wildcard_match(self):
+ pat = glob_to_regex("f?o*baz")
+
+ self.assertTrue(
+ pat.match("FoobarbaZ"),
+ "* should match string and pattern should be case-insensitive",
+ )
+ self.assertTrue(pat.match("foobaz"), "* should match 0 characters")
+ self.assertFalse(pat.match("fooxaz"), "the character after * must match")
+ self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters")
+ self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters")
+
+ def test_multi_wildcard(self):
+ """patterns with multiple wildcards in a row should match"""
+ pat = glob_to_regex("**baz")
+ self.assertTrue(pat.match("agsgsbaz"), "** should match any string")
+ self.assertTrue(pat.match("baz"), "** should match the empty string")
+ self.assertEqual(pat.pattern, r"\A.{0,}baz\Z")
+
+ pat = glob_to_regex("*?baz")
+ self.assertTrue(pat.match("agsgsbaz"), "*? should match any string")
+ self.assertTrue(pat.match("abaz"), "*? should match a single char")
+ self.assertFalse(pat.match("baz"), "*? should not match the empty string")
+ self.assertEqual(pat.pattern, r"\A.{1,}baz\Z")
+
+ pat = glob_to_regex("a?*?*?baz")
+ self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars")
+ self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars")
+ self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars")
+ self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z")
|