summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py170
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py18
-rw-r--r--tests/util/test_batching_queue.py37
3 files changed, 111 insertions, 114 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2775dfd880..745c295d3b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
+from typing import Dict, List
 from unittest.mock import Mock
 
 import attr
@@ -21,7 +22,6 @@ import signedjson.sign
 from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
-from twisted.internet import defer
 from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.api.errors import SynapseError
@@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # deferred completes.
         first_lookup_deferred = Deferred()
 
-        async def first_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request.id, "context_11")
-            self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
+        async def first_lookup_fetch(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            # self.assertEquals(current_context().request.id, "context_11")
+            self.assertEqual(server_name, "server10")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 0)
 
             await make_deferred_yieldable(first_lookup_deferred)
-            return {
-                "server10": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
-                }
-            }
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.side_effect = first_lookup_fetch
 
         async def first_lookup():
             with LoggingContext("context_11", request=FakeRequest("context_11")):
                 res_deferreds = kr.verify_json_objects_for_server(
-                    [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
+                    [("server10", json1, 0), ("server11", {}, 0)]
                 )
 
                 # the unsigned json should be rejected pretty quickly
@@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         d0 = ensureDeferred(first_lookup())
 
+        self.pump()
+
         mock_fetcher.get_keys.assert_called_once()
 
         # a second request for a server with outstanding requests
         # should block rather than start a second call
 
-        async def second_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request.id, "context_12")
-            return {
-                "server10": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
-                }
-            }
+        async def second_lookup_fetch(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            # self.assertEquals(current_context().request.id, "context_12")
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.reset_mock()
         mock_fetcher.get_keys.side_effect = second_lookup_fetch
@@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         async def second_lookup():
             with LoggingContext("context_12", request=FakeRequest("context_12")):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
-                    [("server10", json1, 0, "test")]
+                    [
+                        (
+                            "server10",
+                            json1,
+                            0,
+                        )
+                    ]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
                 second_lookup_state[0] = 1
@@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        d = kr.verify_json_for_server("server9", {}, 0)
         self.get_failure(d, SynapseError)
 
         # should succeed on a signed object
-        d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+        d = kr.verify_json_for_server("server9", json1, 500)
         # self.assertFalse(d.called)
         self.get_success(d)
 
@@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        d = kr.verify_json_for_server("server9", {}, 0)
         self.get_failure(d, SynapseError)
 
         # should fail on a signed object with a non-zero minimum_valid_until_ms,
         # as it tries to refetch the keys and fails.
-        d = _verify_json_for_server(
-            kr, "server9", json1, 500, "test signed non-zero min"
-        )
+        d = kr.verify_json_for_server("server9", json1, 500)
         self.get_failure(d, SynapseError)
 
         # We expect the keyring tried to refetch the key once.
         mock_fetcher.get_keys.assert_called_once_with(
-            {"server9": {get_key_id(key1): 500}}
+            "server9", [get_key_id(key1)], 500
         )
 
         # should succeed on a signed object with a 0 minimum_valid_until_ms
-        d = _verify_json_for_server(
-            kr, "server9", json1, 0, "test signed with zero min"
+        d = kr.verify_json_for_server(
+            "server9",
+            json1,
+            0,
         )
         self.get_success(d)
 
@@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key(1)
 
-        async def get_keys(keys_to_fetch):
+        async def get_keys(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
             # there should only be one request object (with the max validity)
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
 
-            return {
-                "server1": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                }
-            }
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
 
         mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # the first request should succeed; the second should fail because the key
         # has expired
         results = kr.verify_json_objects_for_server(
-            [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+            [
+                (
+                    "server1",
+                    json1,
+                    500,
+                ),
+                ("server1", json1, 1500),
+            ]
         )
         self.assertEqual(len(results), 2)
         self.get_success(results[0])
@@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """If the first fetcher cannot provide a recent enough key, we fall back"""
         key1 = signedjson.key.generate_signing_key(1)
 
-        async def get_keys1(keys_to_fetch):
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return {
-                "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
-            }
-
-        async def get_keys2(keys_to_fetch):
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return {
-                "server1": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                }
-            }
+        async def get_keys1(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+
+        async def get_keys2(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
 
         mock_fetcher1 = Mock()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server1", key1)
 
         results = kr.verify_json_objects_for_server(
-            [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+            [
+                (
+                    "server1",
+                    json1,
+                    1200,
+                ),
+                (
+                    "server1",
+                    json1,
+                    1500,
+                ),
+            ]
         )
         self.assertEqual(len(results), 2)
         self.get_success(results[0])
@@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.get_json.side_effect = get_json
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
 
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
         self.assertEqual(keys, {})
 
 
@@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        self.assertIn(SERVER_NAME, keys)
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        self.assertIn(testverifykey_id, keys)
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        self.assertIn(SERVER_NAME, keys)
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        self.assertIn(testverifykey_id, keys)
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
-            keys_to_fetch = {SERVER_NAME: {"key1": 0}}
             self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
-            return self.get_success(fetcher.get_keys(keys_to_fetch))
+            return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
 
         # start with a valid response so we can check we are testing the right thing
         response = build_response()
         keys = get_key_from_perspectives(response)
-        k = keys[SERVER_NAME][testverifykey_id]
+        k = keys[testverifykey_id]
         self.assertEqual(k.verify_key, testverifykey)
 
         # remove the perspectives server's signature
@@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 def get_key_id(key):
     """Get the matrix ID tag for a given SigningKey or VerifyKey"""
     return "%s:%s" % (key.alg, key.version)
-
-
-@defer.inlineCallbacks
-def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx"):
-        rv = yield f(*args, **kwargs)
-    return rv
-
-
-def _verify_json_for_server(kr, *args):
-    """thin wrapper around verify_json_for_server which makes sure it is wrapped
-    with the patched defer.inlineCallbacks.
-    """
-
-    @defer.inlineCallbacks
-    def v():
-        rv1 = yield kr.verify_json_for_server(*args)
-        return rv1
-
-    return run_in_context(v)
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 3b275bc23b..a75c0ea3f0 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -208,10 +208,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         keyid = "ed25519:%s" % (testkey.version,)
 
         fetcher = PerspectivesKeyFetcher(self.hs2)
-        d = fetcher.get_keys({"targetserver": {keyid: 1000}})
+        d = fetcher.get_keys("targetserver", [keyid], 1000)
         res = self.get_success(d)
-        self.assertIn("targetserver", res)
-        keyres = res["targetserver"][keyid]
+        self.assertIn(keyid, res)
+        keyres = res[keyid]
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
@@ -230,10 +230,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         keyid = "ed25519:%s" % (testkey.version,)
 
         fetcher = PerspectivesKeyFetcher(self.hs2)
-        d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+        d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
         res = self.get_success(d)
-        self.assertIn(self.hs.hostname, res)
-        keyres = res[self.hs.hostname][keyid]
+        self.assertIn(keyid, res)
+        keyres = res[keyid]
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
@@ -247,10 +247,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         keyid = "ed25519:%s" % (self.hs_signing_key.version,)
 
         fetcher = PerspectivesKeyFetcher(self.hs2)
-        d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+        d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
         res = self.get_success(d)
-        self.assertIn(self.hs.hostname, res)
-        keyres = res[self.hs.hostname][keyid]
+        self.assertIn(keyid, res)
+        keyres = res[keyid]
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index edf29e5b96..07be57d72c 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -45,37 +45,32 @@ class BatchingQueueTestCase(TestCase):
         self._pending_calls.append((values, d))
         return await make_deferred_yieldable(d)
 
+    def _get_sample_with_name(self, metric, name) -> int:
+        """For a prometheus metric get the value of the sample that has a
+        matching "name" label.
+        """
+        for sample in metric.collect()[0].samples:
+            if sample.labels.get("name") == name:
+                return sample.value
+
+        self.fail("Found no matching sample")
+
     def _assert_metrics(self, queued, keys, in_flight):
         """Assert that the metrics are correct"""
 
-        self.assertEqual(len(number_queued.collect()), 1)
-        self.assertEqual(len(number_queued.collect()[0].samples), 1)
+        sample = self._get_sample_with_name(number_queued, self.queue._name)
         self.assertEqual(
-            number_queued.collect()[0].samples[0].labels,
-            {"name": self.queue._name},
-        )
-        self.assertEqual(
-            number_queued.collect()[0].samples[0].value,
+            sample,
             queued,
             "number_queued",
         )
 
-        self.assertEqual(len(number_of_keys.collect()), 1)
-        self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
-        self.assertEqual(
-            number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
-        )
-        self.assertEqual(
-            number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
-        )
+        sample = self._get_sample_with_name(number_of_keys, self.queue._name)
+        self.assertEqual(sample, keys, "number_of_keys")
 
-        self.assertEqual(len(number_in_flight.collect()), 1)
-        self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
-        self.assertEqual(
-            number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
-        )
+        sample = self._get_sample_with_name(number_in_flight, self.queue._name)
         self.assertEqual(
-            number_in_flight.collect()[0].samples[0].value,
+            sample,
             in_flight,
             "number_in_flight",
         )