summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_tls.py3
-rw-r--r--tests/rest/admin/test_event_reports.py15
-rw-r--r--tests/rest/admin/test_media.py99
-rw-r--r--tests/rest/client/v2_alpha/test_report_event.py83
-rw-r--r--tests/storage/databases/__init__.py13
-rw-r--r--tests/storage/databases/main/__init__.py13
-rw-r--r--tests/storage/databases/main/test_events_worker.py96
-rw-r--r--tests/util/caches/test_descriptors.py6
-rw-r--r--tests/util/test_batching_queue.py78
9 files changed, 398 insertions, 8 deletions
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py

index 183034f7d4..dcf336416c 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py
@@ -74,12 +74,11 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= config = { "tls_certificate_path": os.path.join(config_dir, "cert.pem"), - "tls_fingerprints": [], } t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - t.read_certificate_from_disk(require_cert_and_key=False) + t.read_tls_certificate() warnings = self.flushWarnings() self.assertEqual(len(warnings), 1) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 29341bc6e9..f15d1cf6f7 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -64,7 +64,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): user_tok=self.admin_user_tok, ) for _ in range(5): - self._create_event_and_report( + self._create_event_and_report_without_parameters( room_id=self.room_id2, user_tok=self.admin_user_tok, ) @@ -378,6 +378,19 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def _create_event_and_report_without_parameters(self, room_id, user_tok): + """Create and report an event, but omit reason and score""" + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({}), + access_token=user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def _check_fields(self, content): """Checks that all attributes are present in an event report""" for c in content: diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index ac7b219700..f741121ea2 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -16,6 +16,8 @@ import json import os from binascii import unhexlify +from parameterized import parameterized + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client.v1 import login, profile, room @@ -562,3 +564,100 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) # Test that the file is deleted self.assertFalse(os.path.exists(local_path)) + + +class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + media_repo = hs.get_media_repository_resource() + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create media + upload_resource = media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + self.media_id = server_and_media_id.split("/")[1] + + self.url = "/_synapse/admin/v1/media/%s/%s" + + @parameterized.expand(["protect", "unprotect"]) + def test_no_auth(self, action: str): + """ + Try to protect media without authentication. + """ + + channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + @parameterized.expand(["protect", "unprotect"]) + def test_requester_is_no_admin(self, action: str): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + channel = self.make_request( + "POST", + self.url % (action, self.media_id), + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_protect_media(self): + """ + Tests that protect and unprotect a media is successfully + """ + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["safe_from_quarantine"]) + + # protect + channel = self.make_request( + "POST", + self.url % ("protect", self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertTrue(media_info["safe_from_quarantine"]) + + # unprotect + channel = self.make_request( + "POST", + self.url % ("unprotect", self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["safe_from_quarantine"]) diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py new file mode 100644
index 0000000000..1ec6b05e5b --- /dev/null +++ b/tests/rest/client/v2_alpha/test_report_event.py
@@ -0,0 +1,83 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import report_event + +from tests import unittest + + +class ReportEventTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) + resp = self.helper.send(self.room_id, tok=self.admin_user_tok) + self.event_id = resp["event_id"] + self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + + def test_reason_str_and_score_int(self): + data = {"reason": "this makes me sad", "score": -100} + self._assert_status(200, data) + + def test_no_reason(self): + data = {"score": 0} + self._assert_status(200, data) + + def test_no_score(self): + data = {"reason": "this makes me sad"} + self._assert_status(200, data) + + def test_no_reason_and_no_score(self): + data = {} + self._assert_status(200, data) + + def test_reason_int_and_score_str(self): + data = {"reason": 10, "score": "string"} + self._assert_status(400, data) + + def test_reason_zero_and_score_blank(self): + data = {"reason": 0, "score": ""} + self._assert_status(400, data) + + def test_reason_and_score_null(self): + data = {"reason": None, "score": None} + self._assert_status(400, data) + + def _assert_status(self, response_status, data): + channel = self.make_request( + "POST", + self.report_path, + json.dumps(data), + access_token=self.other_user_tok, + ) + self.assertEqual( + response_status, int(channel.result["code"]), msg=channel.result["body"] + ) diff --git a/tests/storage/databases/__init__.py b/tests/storage/databases/__init__.py new file mode 100644
index 0000000000..c24c7ecd92 --- /dev/null +++ b/tests/storage/databases/__init__.py
@@ -0,0 +1,13 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/databases/main/__init__.py b/tests/storage/databases/main/__init__.py new file mode 100644
index 0000000000..c24c7ecd92 --- /dev/null +++ b/tests/storage/databases/main/__init__.py
@@ -0,0 +1,13 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py new file mode 100644
index 0000000000..932970fd9a --- /dev/null +++ b/tests/storage/databases/main/test_events_worker.py
@@ -0,0 +1,96 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.logging.context import LoggingContext +from synapse.storage.databases.main.events_worker import EventsWorkerStore + +from tests import unittest + + +class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + # insert some test data + for rid in ("room1", "room2"): + self.get_success( + self.store.db_pool.simple_insert( + "rooms", + {"room_id": rid, "room_version": 4}, + ) + ) + + for idx, (rid, eid) in enumerate( + ( + ("room1", "event10"), + ("room1", "event11"), + ("room1", "event12"), + ("room2", "event20"), + ) + ): + self.get_success( + self.store.db_pool.simple_insert( + "events", + { + "event_id": eid, + "room_id": rid, + "topological_ordering": idx, + "stream_ordering": idx, + "type": "test", + "processed": True, + "outlier": False, + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "event_json", + { + "event_id": eid, + "room_id": rid, + "json": json.dumps({"type": "test", "room_id": rid}), + "internal_metadata": "{}", + "format_version": 3, + }, + ) + ) + + def test_simple(self): + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events("room1", ["event10", "event19"]) + ) + self.assertEquals(res, {"event10"}) + + # that should result in a single db query + self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + + # a second lookup of the same events should cause no queries + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events("room1", ["event10", "event19"]) + ) + self.assertEquals(res, {"event10"}) + self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + def test_query_via_event_cache(self): + # fetch an event into the event cache + self.get_success(self.store.get_event("event10")) + + # looking it up should now cause no db hits + with LoggingContext(name="test") as ctx: + res = self.get_success(self.store.have_seen_events("room1", ["event10"])) + self.assertEquals(res, {"event10"}) + self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index bbbc276697..0277998cbe 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -622,17 +622,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEquals(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) self.assertEquals(callcount[0], 1) self.assertEquals(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") self.assertEquals(callcount[0], 2) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 5def1e56c9..edf29e5b96 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py
@@ -14,7 +14,12 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable -from synapse.util.batching_queue import BatchingQueue +from synapse.util.batching_queue import ( + BatchingQueue, + number_in_flight, + number_of_keys, + number_queued, +) from tests.server import get_clock from tests.unittest import TestCase @@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase): def setUp(self): self.clock, hs_clock = get_clock() + # We ensure that we remove any existing metrics for "test_queue". + try: + number_queued.remove("test_queue") + number_of_keys.remove("test_queue") + number_in_flight.remove("test_queue") + except KeyError: + pass + self._pending_calls = [] self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) @@ -32,6 +45,41 @@ class BatchingQueueTestCase(TestCase): self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) + def _assert_metrics(self, queued, keys, in_flight): + """Assert that the metrics are correct""" + + self.assertEqual(len(number_queued.collect()), 1) + self.assertEqual(len(number_queued.collect()[0].samples), 1) + self.assertEqual( + number_queued.collect()[0].samples[0].labels, + {"name": self.queue._name}, + ) + self.assertEqual( + number_queued.collect()[0].samples[0].value, + queued, + "number_queued", + ) + + self.assertEqual(len(number_of_keys.collect()), 1) + self.assertEqual(len(number_of_keys.collect()[0].samples), 1) + self.assertEqual( + number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} + ) + self.assertEqual( + number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys" + ) + + self.assertEqual(len(number_in_flight.collect()), 1) + self.assertEqual(len(number_in_flight.collect()[0].samples), 1) + self.assertEqual( + number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} + ) + self.assertEqual( + number_in_flight.collect()[0].samples[0].value, + in_flight, + "number_in_flight", + ) + def test_simple(self): """Tests the basic case of calling `add_to_queue` once and having `_process_queue` return. @@ -41,6 +89,8 @@ class BatchingQueueTestCase(TestCase): queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo")) + self._assert_metrics(queued=1, keys=1, in_flight=1) + # The queue should wait a reactor tick before calling the processing # function. self.assertFalse(self._pending_calls) @@ -52,12 +102,15 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo"]) self.assertFalse(queue_d.called) + self._assert_metrics(queued=0, keys=0, in_flight=1) # Return value of the `_process_queue` should be propagated back. self._pending_calls.pop()[1].callback("bar") self.assertEqual(self.successResultOf(queue_d), "bar") + self._assert_metrics(queued=0, keys=0, in_flight=0) + def test_batching(self): """Test that multiple calls at the same time get batched up into one call to `_process_queue`. @@ -68,6 +121,8 @@ class BatchingQueueTestCase(TestCase): queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + self._assert_metrics(queued=2, keys=1, in_flight=2) + self.clock.pump([0]) # We should see only *one* call to `_process_queue` @@ -75,12 +130,14 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"]) self.assertFalse(queue_d1.called) self.assertFalse(queue_d2.called) + self._assert_metrics(queued=0, keys=0, in_flight=2) # Return value of the `_process_queue` should be propagated back to both. self._pending_calls.pop()[1].callback("bar") self.assertEqual(self.successResultOf(queue_d1), "bar") self.assertEqual(self.successResultOf(queue_d2), "bar") + self._assert_metrics(queued=0, keys=0, in_flight=0) def test_queuing(self): """Test that we queue up requests while a `_process_queue` is being @@ -92,13 +149,20 @@ class BatchingQueueTestCase(TestCase): queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) self.clock.pump([0]) + self.assertEqual(len(self._pending_calls), 1) + + # We queue up work after the process function has been called, testing + # that they get correctly queued up. queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3")) # We should see only *one* call to `_process_queue` self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo1"]) self.assertFalse(queue_d1.called) self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + self._assert_metrics(queued=2, keys=1, in_flight=3) # Return value of the `_process_queue` should be propagated back to the # first. @@ -106,18 +170,24 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d1), "bar1") self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + self._assert_metrics(queued=2, keys=1, in_flight=2) # We should now see a second call to `_process_queue` self.clock.pump([0]) self.assertEqual(len(self._pending_calls), 1) - self.assertEqual(self._pending_calls[0][0], ["foo2"]) + self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"]) self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + self._assert_metrics(queued=0, keys=0, in_flight=2) # Return value of the `_process_queue` should be propagated back to the # second. self._pending_calls.pop()[1].callback("bar2") self.assertEqual(self.successResultOf(queue_d2), "bar2") + self.assertEqual(self.successResultOf(queue_d3), "bar2") + self._assert_metrics(queued=0, keys=0, in_flight=0) def test_different_keys(self): """Test that calls to different keys get processed in parallel.""" @@ -140,6 +210,7 @@ class BatchingQueueTestCase(TestCase): self.assertFalse(queue_d1.called) self.assertFalse(queue_d2.called) self.assertFalse(queue_d3.called) + self._assert_metrics(queued=1, keys=1, in_flight=3) # Return value of the `_process_queue` should be propagated back to the # first. @@ -148,6 +219,7 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d1), "bar1") self.assertFalse(queue_d2.called) self.assertFalse(queue_d3.called) + self._assert_metrics(queued=1, keys=1, in_flight=2) # Return value of the `_process_queue` should be propagated back to the # second. @@ -161,9 +233,11 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo3"]) self.assertFalse(queue_d3.called) + self._assert_metrics(queued=0, keys=0, in_flight=1) # Return value of the `_process_queue` should be propagated back to the # third deferred. self._pending_calls.pop()[1].callback("bar4") self.assertEqual(self.successResultOf(queue_d3), "bar4") + self._assert_metrics(queued=0, keys=0, in_flight=0)