From 1d11b452b70c768e4919bd9cf6bcaeda2050a3d4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 3 Mar 2022 10:43:06 -0500 Subject: Use the proper serialization format when bundling aggregations. (#12090) This ensures that the `latest_event` field of the bundled aggregation for threads uses the same format as the other events in the response. --- tests/rest/client/test_relations.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 709f851a38..53062b41de 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -704,10 +704,8 @@ class RelationsTestCase(BaseRelationsTestCase): } }, "event_id": thread_2, - "room_id": self.room, "sender": self.user_id, "type": "m.room.test", - "user_id": self.user_id, }, relations_dict[RelationTypes.THREAD].get("latest_event"), ) -- cgit 1.5.1 From 7e91107be1a4287873266e588a3c5b415279f4c8 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 3 Mar 2022 17:05:44 +0100 Subject: Add type hints to `tests/rest` (#12146) * Add type hints to `tests/rest` * newsfile * change import from `SigningKey` --- changelog.d/12146.misc | 1 + mypy.ini | 7 +--- tests/rest/key/v2/test_remote_key_resource.py | 44 ++++++++++++++-------- tests/rest/media/v1/test_base.py | 4 +- tests/rest/media/v1/test_filepath.py | 48 ++++++++++++------------ tests/rest/media/v1/test_html_preview.py | 54 +++++++++++++-------------- tests/rest/media/v1/test_oembed.py | 10 ++--- tests/rest/test_health.py | 8 ++-- tests/rest/test_well_known.py | 20 +++++----- 9 files changed, 104 insertions(+), 92 deletions(-) create mode 100644 changelog.d/12146.misc (limited to 'tests/rest') diff --git a/changelog.d/12146.misc b/changelog.d/12146.misc new file mode 100644 index 0000000000..3ca7c47212 --- /dev/null +++ b/changelog.d/12146.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest`. diff --git a/mypy.ini b/mypy.ini index 10971b7225..481e8a5366 100644 --- a/mypy.ini +++ b/mypy.ini @@ -89,8 +89,6 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py - |tests/rest/key/v2/test_remote_key_resource.py - |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_url_preview.py |tests/scripts/test_new_matrix_user.py @@ -254,10 +252,7 @@ disallow_untyped_defs = True [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True -[mypy-tests.rest.admin.*] -disallow_untyped_defs = True - -[mypy-tests.rest.client.*] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 4672a68596..978c252f84 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -13,19 +13,24 @@ # limitations under the License. import urllib.parse from io import BytesIO, StringIO +from typing import Any, Dict, Optional, Union from unittest.mock import Mock import signedjson.key from canonicaljson import encode_canonical_json -from nacl.signing import SigningKey from signedjson.sign import sign_json +from signedjson.types import SigningKey -from twisted.web.resource import NoResource +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string @@ -35,11 +40,11 @@ from tests.utils import default_config class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() return self.setup_test_homeserver(federation_http_client=self.http_client) - def create_test_resource(self): + def create_test_resource(self) -> Resource: return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) @@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - async def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json( + destination: str, + path: str, + ignore_backoff: bool = False, + **kwargs: Any, + ) -> Union[JsonDict, list]: self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): Checks that the response is a 200 and returns the decoded json body. """ channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): resp = channel.json_body return resp - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a remote key""" SERVER_NAME = "remote.server" testkey = signedjson.key.generate_signing_key("ver1") @@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): self.assertIn(SERVER_NAME, keys[0]["signatures"]) self.assertIn(self.hs.hostname, keys[0]["signatures"]) - def test_get_own_key(self): + def test_get_own_key(self) -> None: """Fetch our own key""" testkey = signedjson.key.generate_signing_key("ver1") self.expect_outgoing_key_request(self.hs.hostname, testkey) @@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. """ - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # replace the signing key with our own @@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # make a second homeserver, configured to use the first one as a key notary self.http_client2 = Mock() config = default_config(name="keyclient") @@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - async def post_json(destination, path, data): + async def post_json( + destination: str, path: str, data: Optional[JsonDict] = None + ) -> Union[JsonDict, list]: self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, @@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): ) channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( @@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): self.http_client2.post_json.side_effect = post_json - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a key belonging to a random server""" # make up a key to be fetched. testkey = signedjson.key.generate_signing_key("abc") @@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_key(self): + def test_get_notary_key(self) -> None: """Fetch a key belonging to the notary server""" # make up a key to be fetched. We randomise the keyid to try to get it to # appear before the key server signing key sometimes (otherwise we bail out @@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_keyserver_key(self): + def test_get_notary_keyserver_key(self) -> None: """Fetch the notary's keyserver key""" # we expect hs1 to make a regular key request to itself self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index f761e23f1b..c73179151a 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", } - def tests(self): + def tests(self) -> None: for hdr, expected in self.TEST_CASES.items(): res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, - "expected output for %s to be %s but was %s" % (hdr, expected, res), + f"expected output for {hdr!r} to be {expected} but was {res}", ) diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 913bc530aa..43e6f0f70a 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -21,12 +21,12 @@ from tests import unittest class MediaFilePathsTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.filepaths = MediaFilePaths("/media_store") - def test_local_media_filepath(self): + def test_local_media_filepath(self) -> None: """Test local media paths""" self.assertEqual( self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_local_media_thumbnail(self): + def test_local_media_thumbnail(self) -> None: """Test local media thumbnail paths""" self.assertEqual( self.filepaths.local_media_thumbnail_rel( @@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_local_media_thumbnail_dir(self): + def test_local_media_thumbnail_dir(self) -> None: """Test local media thumbnail directory paths""" self.assertEqual( self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_filepath(self): + def test_remote_media_filepath(self) -> None: """Test remote media paths""" self.assertEqual( self.filepaths.remote_media_filepath_rel( @@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_thumbnail(self): + def test_remote_media_thumbnail(self) -> None: """Test remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel( @@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_remote_media_thumbnail_legacy(self): + def test_remote_media_thumbnail_legacy(self) -> None: """Test old-style remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel_legacy( @@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", ) - def test_remote_media_thumbnail_dir(self): + def test_remote_media_thumbnail_dir(self) -> None: """Test remote media thumbnail directory paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_dir( @@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath(self): + def test_url_cache_filepath(self) -> None: """Test URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), @@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_filepath_legacy(self): + def test_url_cache_filepath_legacy(self) -> None: """Test old-style URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath_dirs_to_delete(self): + def test_url_cache_filepath_dirs_to_delete(self) -> None: """Test URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ["/media_store/url_cache/2020-01-02"], ) - def test_url_cache_filepath_dirs_to_delete_legacy(self): + def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_url_cache_thumbnail(self): + def test_url_cache_thumbnail(self) -> None: """Test URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_legacy(self): + def test_url_cache_thumbnail_legacy(self) -> None: """Test old-style URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_directory(self): + def test_url_cache_thumbnail_directory(self) -> None: """Test URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_thumbnail_directory_legacy(self): + def test_url_cache_thumbnail_directory_legacy(self) -> None: """Test old-style URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_thumbnail_dirs_to_delete(self): + def test_url_cache_thumbnail_dirs_to_delete(self) -> None: """Test URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_url_cache_thumbnail_dirs_to_delete_legacy(self): + def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_server_name_validation(self): + def test_server_name_validation(self) -> None: """Test validation of server names""" self._test_path_validation( [ @@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_file_id_validation(self): + def test_file_id_validation(self) -> None: """Test validation of local, remote and legacy URL cache file / media IDs""" # File / media IDs get split into three parts to form paths, consisting of the # first two characters, next two characters and rest of the ID. @@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase): invalid_values=invalid_file_ids, ) - def test_url_cache_media_id_validation(self): + def test_url_cache_media_id_validation(self) -> None: """Test validation of URL cache media IDs""" self._test_path_validation( [ @@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_content_type_validation(self): + def test_content_type_validation(self) -> None: """Test validation of thumbnail content types""" self._test_path_validation( [ @@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_thumbnail_method_validation(self): + def test_thumbnail_method_validation(self) -> None: """Test validation of thumbnail methods""" self._test_path_validation( [ @@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase): parameter: str, valid_values: Iterable[str], invalid_values: Iterable[str], - ): + ) -> None: """Test that the specified methods validate the named parameter as expected Args: diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index a4b57e3d1f..3fb37a2a59 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_long_summarize(self): + def test_long_summarize(self) -> None: example_paras = [ """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in @@ -90,7 +90,7 @@ class SummarizeTestCase(unittest.TestCase): " Tromsøya had a population of 36,088. Substantial parts of the urban…", ) - def test_short_summarize(self): + def test_short_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -117,7 +117,7 @@ class SummarizeTestCase(unittest.TestCase): " most of the year.", ) - def test_small_then_large_summarize(self): + def test_small_then_large_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_simple(self): + def test_simple(self) -> None: html = b""" Foo @@ -165,7 +165,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment(self): + def test_comment(self) -> None: html = b""" Foo @@ -181,7 +181,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment2(self): + def test_comment2(self) -> None: html = b""" Foo @@ -206,7 +206,7 @@ class CalcOgTestCase(unittest.TestCase): }, ) - def test_script(self): + def test_script(self) -> None: html = b""" Foo @@ -222,7 +222,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_missing_title(self): + def test_missing_title(self) -> None: html = b""" @@ -236,7 +236,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_h1_as_title(self): + def test_h1_as_title(self) -> None: html = b""" @@ -251,7 +251,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - def test_missing_title_and_broken_h1(self): + def test_missing_title_and_broken_h1(self) -> None: html = b""" @@ -266,19 +266,19 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_empty(self): + def test_empty(self) -> None: """Test a body with no data in it.""" html = b"" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_no_tree(self): + def test_no_tree(self) -> None: """A valid body with no tree in it.""" html = b"\x00" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_xml(self): + def test_xml(self) -> None: """Test decoding XML and ensure it works properly.""" # Note that the strip() call is important to ensure the xml tag starts # at the initial byte. @@ -293,7 +293,7 @@ class CalcOgTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding(self): + def test_invalid_encoding(self) -> None: """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" html = b""" @@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding2(self): + def test_invalid_encoding2(self) -> None: """A body which doesn't match the sent character encoding.""" # Note that this contains an invalid UTF-8 sequence in the title. html = b""" @@ -322,7 +322,7 @@ class CalcOgTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - def test_windows_1252(self): + def test_windows_1252(self) -> None: """A body which uses cp1252, but doesn't declare that.""" html = b""" @@ -338,7 +338,7 @@ class CalcOgTestCase(unittest.TestCase): class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self): + def test_meta_charset(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -363,7 +363,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_charset_underscores(self): + def test_meta_charset_underscores(self) -> None: """A character encoding contains underscore.""" encodings = _get_html_media_encodings( b""" @@ -376,7 +376,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - def test_xml_encoding(self): + def test_xml_encoding(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -388,7 +388,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_xml_encoding(self): + def test_meta_xml_encoding(self) -> None: """Meta tags take precedence over XML encoding.""" encodings = _get_html_media_encodings( b""" @@ -402,7 +402,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - def test_content_type(self): + def test_content_type(self) -> None: """A character encoding is found via the Content-Type header.""" # Test a few variations of the header. headers = ( @@ -417,12 +417,12 @@ class MediaEncodingTestCase(unittest.TestCase): encodings = _get_html_media_encodings(b"", header) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_fallback(self): + def test_fallback(self) -> None: """A character encoding cannot be found in the body or header.""" encodings = _get_html_media_encodings(b"", "text/html") self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_duplicates(self): + def test_duplicates(self) -> None: """Ensure each encoding is only attempted once.""" encodings = _get_html_media_encodings( b""" @@ -436,7 +436,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_unknown_invalid(self): + def test_unknown_invalid(self) -> None: """A character encoding should be ignored if it is unknown or invalid.""" encodings = _get_html_media_encodings( b""" @@ -451,7 +451,7 @@ class MediaEncodingTestCase(unittest.TestCase): class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self): + def test_relative(self) -> None: """Relative URLs should be resolved based on the context of the base URL.""" self.assertEqual( rebase_url("subpage", "https://example.com/foo/"), @@ -466,14 +466,14 @@ class RebaseUrlTestCase(unittest.TestCase): "https://example.com/bar", ) - def test_absolute(self): + def test_absolute(self) -> None: """Absolute URLs should not be modified.""" self.assertEqual( rebase_url("https://alice.com/a/", "https://example.com/foo/"), "https://alice.com/a/", ) - def test_data(self): + def test_data(self) -> None: """Data URLs should not be modified.""" self.assertEqual( rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index 048d0ca44a..f38d7225f8 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -16,7 +16,7 @@ import json from twisted.test.proto_helpers import MemoryReactor -from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase class OEmbedTests(HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): - self.oembed = OEmbedProvider(homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.oembed = OEmbedProvider(hs) - def parse_response(self, response: JsonDict): + def parse_response(self, response: JsonDict) -> OEmbedResult: return self.oembed.parse_oembed_response( "https://test", json.dumps(response).encode("utf-8") ) - def test_version(self): + def test_version(self) -> None: """Accept versions that are similar to 1.0 as a string or int (or missing).""" for version in ("1.0", 1.0, 1): result = self.parse_response({"version": version, "type": "link"}) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 01d48c3860..da325955f8 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from synapse.rest.health import HealthResource @@ -19,12 +19,12 @@ from tests import unittest class HealthCheckTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> HealthResource: # replace the JsonResource with a HealthResource. return HealthResource() - def test_health(self): + def test_health(self) -> None: channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 118aa93a32..11f78f52b8 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource @@ -19,7 +21,7 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> Resource: # replace the JsonResource with a Resource wrapping the WellKnownResource res = Resource() res.putChild(b".well-known", well_known_resource(self.hs)) @@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "default_identity_server": "https://testis", } ) - def test_client_well_known(self): + def test_client_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, { @@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase): "public_baseurl": None, } ) - def test_client_well_known_no_public_baseurl(self): + def test_client_well_known_no_public_baseurl(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @unittest.override_config({"serve_server_wellknown": True}) - def test_server_well_known(self): + def test_server_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, {"m.server": "test:443"}, ) - def test_server_well_known_disabled(self): + def test_server_well_known_disabled(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) -- cgit 1.5.1 From cd1ae3d0b438ff453b7d4750c4fe901f266fcbb6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Mar 2022 07:10:10 -0500 Subject: Remove backwards compatibility with RelationPaginationToken. (#12138) --- changelog.d/12138.removal | 1 + synapse/rest/client/relations.py | 55 +++++++--------------------- synapse/storage/relations.py | 31 ---------------- tests/rest/client/test_relations.py | 73 +------------------------------------ 4 files changed, 16 insertions(+), 144 deletions(-) create mode 100644 changelog.d/12138.removal (limited to 'tests/rest') diff --git a/changelog.d/12138.removal b/changelog.d/12138.removal new file mode 100644 index 0000000000..6ed84d476c --- /dev/null +++ b/changelog.d/12138.removal @@ -0,0 +1 @@ +Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 487ea38b55..07fa1cdd4c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,50 +27,15 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -async def _parse_token( - store: "DataStore", token: Optional[str] -) -> Optional[StreamToken]: - """ - For backwards compatibility support RelationPaginationToken, but new pagination - tokens are generated as full StreamTokens, to be compatible with /sync and /messages. - """ - if not token: - return None - # Luckily the format for StreamToken and RelationPaginationToken differ enough - # that they can easily be separated. An "_" appears in the serialization of - # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses - # "-" only for separators. - if "_" in token: - return await StreamToken.from_string(store, token) - else: - relation_token = RelationPaginationToken.from_string(token) - return StreamToken( - room_key=RoomStreamToken(relation_token.topological, relation_token.stream), - presence_key=0, - typing_key=0, - receipt_key=0, - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=0, - groups_key=0, - ) - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -122,8 +87,12 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -317,8 +286,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 36ca2b8273..fba270150b 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -54,37 +54,6 @@ class PaginationChunk: return d -@attr.s(frozen=True, slots=True, auto_attribs=True) -class RelationPaginationToken: - """Pagination token for relation pagination API. - - As the results are in topological order, we can use the - `topological_ordering` and `stream_ordering` fields of the events at the - boundaries of the chunk as pagination tokens. - - Attributes: - topological: The topological ordering of the boundary event - stream: The stream ordering of the boundary event. - """ - - topological: int - stream: int - - @staticmethod - def from_string(string: str) -> "RelationPaginationToken": - try: - t, s = string.split("-") - return RelationPaginationToken(int(t), int(s)) - except ValueError: - raise SynapseError(400, "Invalid relation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.topological, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) - - @attr.s(frozen=True, slots=True, auto_attribs=True) class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 53062b41de..274f9c44c1 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -24,8 +24,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync from synapse.server import HomeServer -from synapse.storage.relations import RelationPaginationToken -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -281,15 +280,6 @@ class RelationsTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - def _stream_token_to_relation_token(self, token: str) -> str: - """Convert a StreamToken into a legacy token (RelationPaginationToken).""" - room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key - return self.get_success( - RelationPaginationToken( - topological=room_key.topological, stream=room_key.stream - ).to_string(self.store) - ) - def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -330,34 +320,6 @@ class RelationsTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") @@ -543,39 +505,6 @@ class RelationsTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" -- cgit 1.5.1 From f63bedef07360216a8de71dc38f00f1aea503903 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Mar 2022 09:00:05 -0500 Subject: Invalidate caches when an event with a relation is redacted. (#12121) The caches for the target of the relation must be cleared so that the bundled aggregations are re-calculated after the redaction is processed. --- changelog.d/12113.bugfix | 1 + changelog.d/12113.misc | 1 - changelog.d/12121.bugfix | 1 + synapse/storage/databases/main/cache.py | 2 + synapse/storage/databases/main/events.py | 38 +++++- tests/rest/client/test_relations.py | 207 ++++++++++++++++++++++++------- 6 files changed, 202 insertions(+), 48 deletions(-) create mode 100644 changelog.d/12113.bugfix delete mode 100644 changelog.d/12113.misc create mode 100644 changelog.d/12121.bugfix (limited to 'tests/rest') diff --git a/changelog.d/12113.bugfix b/changelog.d/12113.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12113.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc deleted file mode 100644 index 102e064053..0000000000 --- a/changelog.d/12113.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the tests for event relations. diff --git a/changelog.d/12121.bugfix b/changelog.d/12121.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12121.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c428dd5596..abd54c7dc7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -200,6 +200,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_relations_for_event.invalidate((relates_to,)) self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) + self.get_thread_summary.invalidate((relates_to,)) + self.get_thread_participated.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ca2a9ba9d1..1dc83aa5e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1518,7 +1518,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redaction(txn, event.redacts) + self._handle_redact_relations(txn, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1943,15 +1943,43 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. + def _handle_redact_relations( + self, txn: LoggingTransaction, redacted_event_id: str + ) -> None: + """Handles receiving a redaction and checking whether the redacted event + has any relations which must be removed from the database. Args: txn - redacted_event_id (str): The event that was redacted. + redacted_event_id: The event that was redacted. """ + # Fetch the current relation of the event being redacted. + redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_relations", + keyvalues={"event_id": redacted_event_id}, + retcol="relates_to_id", + allow_none=True, + ) + # Any relation information for the related event must be cleared. + if redacted_relates_to is not None: + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_applicable_edit, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_summary, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_participated, (redacted_relates_to,) + ) + self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 274f9c44c1..a40a5de399 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1273,7 +1273,21 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationRedactionTestCase(BaseRelationsTestCase): - """Test the behaviour of relations when the parent or child event is redacted.""" + """ + Test the behaviour of relations when the parent or child event is redacted. + + The behaviour of each relation type is subtly different which causes the tests + to be a bit repetitive, they follow a naming scheme of: + + test_redact_(relation|parent)_{relation_type} + + The first bit of "relation" means that the event with the relation defined + on it (the child event) is to be redacted. A "parent" means that the target + of the relation (the parent event) is to be redacted. + + The relation_type describes which type of relation is under test (i.e. it is + related to the value of rel_type in the event content). + """ def _redact(self, event_id: str) -> None: channel = self.make_request( @@ -1284,9 +1298,53 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: + """ + Makes requests and ensures they result in a 200 response, returns a + tuple of results: + + 1. `/relations` -> Returns a list of event IDs. + 2. `/event` -> Returns the response's m.relations field (from unsigned), + if it exists. + """ + + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) + + return event_ids, bundled_relations + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + def test_redact_relation_annotation(self) -> None: - """Test that annotations of an event are properly handled after the + """ + Test that annotations of an event are properly handled after the annotation is redacted. + + The redacted relation should not be included in bundled aggregations or + the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(200, channel.code, channel.json_body) @@ -1296,24 +1354,97 @@ class RelationRedactionTestCase(BaseRelationsTestCase): RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + + # Both relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, + ) + + # Both relations appear in the aggregation. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) # Redact one of the reactions. self._redact(to_redact_event_id) - # Ensure that the aggregations are correct. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + # The unredacted aggregation should still exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_relation_thread(self) -> None: + """ + Test that thread replies are properly handled after the thread reply redacted. + + The redacted event should not be included in bundled aggregations or + the response to relations. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + # Note that the *last* event in the thread is redacted, as that gets + # included in the bundled aggregation. + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 2", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + # Both relations exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 2, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event returned is the event that will be redacted. self.assertEqual( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + to_redact_event_id, ) - def test_redact_relation_edit(self) -> None: + # Redact one of the reactions. + self._redact(to_redact_event_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event is now the unredacted event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + unredacted_event_id, + ) + + def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ @@ -1331,34 +1462,19 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.REPLACE, relations) # Redact the original event self._redact(self.parent_id) - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # The relations are not returned. + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 0) + self.assertEqual(relations, {}) - def test_redact_parent(self) -> None: + def test_redact_parent_annotation(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1366,16 +1482,23 @@ class RelationRedactionTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + # The relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.ANNOTATION, relations) + + # The aggregation should exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) + # Redact the original event. self._redact(self.parent_id) - # Check that aggregations returns zero - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + # The relations are not returned. + event_ids, relations = self._make_relation_requests() + self.assertEqual(event_ids, []) + self.assertEqual(relations, {}) - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # There's nothing to aggregate. + chunk = self._get_aggregations() + self.assertEqual(chunk, []) -- cgit 1.5.1 From 26211fec24d8d0a967de33147e148166359ec8cb Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Mar 2022 09:44:33 -0800 Subject: Fix a bug in background updates wherein background updates are never run using the default batch size (#12157) --- changelog.d/12157.bugfix | 1 + synapse/storage/background_updates.py | 8 +++++--- tests/rest/admin/test_background_updates.py | 18 ++++++++---------- tests/storage/test_background_update.py | 4 ++-- 4 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/12157.bugfix (limited to 'tests/rest') diff --git a/changelog.d/12157.bugfix b/changelog.d/12157.bugfix new file mode 100644 index 0000000000..c3d2e700bb --- /dev/null +++ b/changelog.d/12157.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d64910aded..4acc2c997d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -102,10 +102,12 @@ class BackgroundUpdatePerformance: Returns: A duration in ms as a float """ - if self.avg_duration_ms == 0: - return 0 - elif self.total_item_count == 0: + # We want to return None if this is the first background update item + if self.total_item_count == 0: return None + # Avoid dividing by zero + elif self.avg_duration_ms == 0: + return 0 else: # Use the exponential moving average so that we can adapt to # changes in how long the update process takes. diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index fb36aa9940..becec84524 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -155,10 +155,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -210,10 +210,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -239,10 +239,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -278,11 +278,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.05263157894736842, "total_duration_ms": 2000.0, - "total_item_count": ( - 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE - ), + "total_item_count": (110), } }, "enabled": True, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 39dcc094bd..9fdf54ea31 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -66,13 +66,13 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), - by=0.01, + by=0.02, ) self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update -- cgit 1.5.1 From 15382b1afad65366df13c3b9040b6fdfb1eccfca Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Wed, 9 Mar 2022 18:23:57 +0000 Subject: Add third_party module callbacks to check if a user can delete a room and deactivate a user (#12028) * Add check_can_deactivate_user * Add check_can_shutdown_rooms * Documentation * callbacks, not functions * Various suggested tweaks * Add tests for test_check_can_shutdown_room and test_check_can_deactivate_user * Update check_can_deactivate_user to not take a Requester * Fix check_can_shutdown_room docs * Renegade and use `by_admin` instead of `admin_user_id` * fix lint * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier * Update docs/modules/third_party_rules_callbacks.md Co-authored-by: Brendan Abolivier Co-authored-by: Brendan Abolivier --- changelog.d/12028.feature | 1 + docs/modules/third_party_rules_callbacks.md | 43 ++++++++++ synapse/events/third_party_rules.py | 55 +++++++++++++ synapse/handlers/deactivate_account.py | 12 ++- synapse/handlers/room.py | 8 ++ synapse/module_api/__init__.py | 6 ++ synapse/rest/admin/rooms.py | 9 +++ tests/rest/client/test_third_party_rules.py | 121 ++++++++++++++++++++++++++++ 8 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12028.feature (limited to 'tests/rest') diff --git a/changelog.d/12028.feature b/changelog.d/12028.feature new file mode 100644 index 0000000000..5549c8f6fc --- /dev/null +++ b/changelog.d/12028.feature @@ -0,0 +1 @@ +Add third-party rules rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 09ac838107..1d3c39967f 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,49 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `check_can_shutdown_room` + +_First introduced in Synapse v1.55.0_ + +```python +async def check_can_shutdown_room( + user_id: str, room_id: str, +) -> bool: +``` + +Called when an admin user requests the shutdown of a room. The module must return a +boolean indicating whether the shutdown can go through. If the callback returns `False`, +the shutdown will not proceed and the caller will see a `M_FORBIDDEN` error. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + +### `check_can_deactivate_user` + +_First introduced in Synapse v1.55.0_ + +```python +async def check_can_deactivate_user( + user_id: str, by_admin: bool, +) -> bool: +``` + +Called when the deactivation of a user is requested. User deactivation can be +performed by an admin or the user themselves, so developers are encouraged to check the +requester when implementing this callback. The module must return a +boolean indicating whether the deactivation can go through. If the callback returns `False`, +the deactivation will not proceed and the caller will see a `M_FORBIDDEN` error. + +The module is passed two parameters, `user_id` which is the ID of the user being deactivated, and `by_admin` which is `True` if the request is made by a serve admin, and `False` otherwise. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + + ### `on_profile_update` _First introduced in Synapse v1.54.0_ diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index ede72ee876..bfca454f51 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -38,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] +CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] @@ -157,6 +159,12 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._check_can_shutdown_room_callbacks: List[ + CHECK_CAN_SHUTDOWN_ROOM_CALLBACK + ] = [] + self._check_can_deactivate_user_callbacks: List[ + CHECK_CAN_DEACTIVATE_USER_CALLBACK + ] = [] self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] self._on_user_deactivation_status_changed_callbacks: List[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -173,6 +181,8 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, + check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -198,6 +208,11 @@ class ThirdPartyEventRules: if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if check_can_shutdown_room is not None: + self._check_can_shutdown_room_callbacks.append(check_can_shutdown_room) + + if check_can_deactivate_user is not None: + self._check_can_deactivate_user_callbacks.append(check_can_deactivate_user) if on_profile_update is not None: self._on_profile_update_callbacks.append(on_profile_update) @@ -369,6 +384,46 @@ class ThirdPartyEventRules: "Failed to run module API callback %s: %s", callback, e ) + async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool: + """Intercept requests to shutdown a room. If `False` is returned, the + room must not be shut down. + + Args: + requester: The ID of the user requesting the shutdown. + room_id: The ID of the room. + """ + for callback in self._check_can_shutdown_room_callbacks: + try: + if await callback(user_id, room_id) is False: + return False + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + return True + + async def check_can_deactivate_user( + self, + user_id: str, + by_admin: bool, + ) -> bool: + """Intercept requests to deactivate a user. If `False` is returned, the + user should not be deactivated. + + Args: + requester + user_id: The ID of the room. + """ + for callback in self._check_can_deactivate_user_callbacks: + try: + if await callback(user_id, by_admin) is False: + return False + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + return True + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 76ae768e6e..816e1a6d79 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import Requester, UserID, create_requester +from synapse.types import Codes, Requester, UserID, create_requester if TYPE_CHECKING: from synapse.server import HomeServer @@ -42,6 +42,7 @@ class DeactivateAccountHandler: # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False + self._third_party_rules = hs.get_third_party_event_rules() # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). @@ -74,6 +75,15 @@ class DeactivateAccountHandler: Returns: True if identity server supports removing threepids, otherwise False. """ + + # Check if this user can be deactivated + if not await self._third_party_rules.check_can_deactivate_user( + user_id, by_admin + ): + raise SynapseError( + 403, "Deactivation of this user is forbidden", Codes.FORBIDDEN + ) + # FIXME: Theoretically there is a race here wherein user resets # password using threepid. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7b965b4b96..b9735631fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1475,6 +1475,7 @@ class RoomShutdownHandler: self.room_member_handler = hs.get_room_member_handler() self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() + self._third_party_rules = hs.get_third_party_event_rules() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main @@ -1548,6 +1549,13 @@ class RoomShutdownHandler: if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + if not await self._third_party_rules.check_can_shutdown_room( + requester_user_id, room_id + ): + raise SynapseError( + 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN + ) + # Action the block first (even if the room doesn't exist yet) if block: # This will work even if the room is already blocked, but that is diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c42eeedd87..d735c1d461 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -54,6 +54,8 @@ from synapse.events.spamcheck import ( USER_MAY_SEND_3PID_INVITE_CALLBACK, ) from synapse.events.third_party_rules import ( + CHECK_CAN_DEACTIVATE_USER_CALLBACK, + CHECK_CAN_SHUTDOWN_ROOM_CALLBACK, CHECK_EVENT_ALLOWED_CALLBACK, CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, @@ -283,6 +285,8 @@ class ModuleApi: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, + check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -298,6 +302,8 @@ class ModuleApi: check_threepid_can_be_invited=check_threepid_can_be_invited, check_visibility_can_be_modified=check_visibility_can_be_modified, on_new_event=on_new_event, + check_can_shutdown_room=check_can_shutdown_room, + check_can_deactivate_user=check_can_deactivate_user, on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f4736a3dad..356d6f74d7 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -67,6 +67,7 @@ class RoomRestV2Servlet(RestServlet): self._auth = hs.get_auth() self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() + self._third_party_rules = hs.get_third_party_event_rules() async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -106,6 +107,14 @@ class RoomRestV2Servlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) + # Check this here, as otherwise we'll only fail after the background job has been started. + if not await self._third_party_rules.check_can_shutdown_room( + requester.user.to_string(), room_id + ): + raise SynapseError( + 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN + ) + delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 58f1ea11b7..e7de67e3a3 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(args[0], user_id) self.assertFalse(args[1]) self.assertTrue(args[2]) + + def test_check_can_deactivate_user(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the request was not made by an admin + self.assertEqual(args[1], False) + + def test_check_can_deactivate_user_admin(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the mock was made by an admin + self.assertEqual(args[1], True) + + def test_check_can_shutdown_room(self) -> None: + """Tests that the check_can_shutdown_room module callback is called + correctly when processing an admin's shutdown room request. + """ + # Register a mocked callback. + shutdown_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_shutdown_room_callbacks.append( + shutdown_mock, + ) + + # Register an admin user. + admin_user_id = self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Shutdown the room. + channel = self.make_request( + "DELETE", + "/_synapse/admin/v2/rooms/%s" % self.room_id, + {}, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + shutdown_mock.assert_called_once() + args = shutdown_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], admin_user_id) + + # Check that the mock was called with the right room ID + self.assertEqual(args[1], self.room_id) -- cgit 1.5.1 From 88cd6f937807e64c05458cec86ef0ba0c1c656b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 09:03:59 -0500 Subject: Allow retrieving the relations of a redacted event. (#12130) This is allowed per MSC2675, although the original implementation did not allow for it and would return an empty chunk / not bundle aggregations. The main thing to improve is that the various caches get cleared properly when an event is redacted, and that edits must not leak if the original event is redacted (as that would presumably leak something similar to the original event content). --- changelog.d/12130.bugfix | 1 + changelog.d/12189.bugfix | 1 + changelog.d/12189.misc | 1 - synapse/rest/client/relations.py | 82 +++++++++++++---------------- synapse/storage/databases/main/cache.py | 4 ++ synapse/storage/databases/main/events.py | 11 ++-- synapse/storage/databases/main/relations.py | 60 +++++++++++---------- tests/rest/client/test_relations.py | 45 ++++++++++++++-- 8 files changed, 122 insertions(+), 83 deletions(-) create mode 100644 changelog.d/12130.bugfix create mode 100644 changelog.d/12189.bugfix delete mode 100644 changelog.d/12189.misc (limited to 'tests/rest') diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12130.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12189.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc deleted file mode 100644 index 015e808e63..0000000000 --- a/changelog.d/12189.misc +++ /dev/null @@ -1 +0,0 @@ -Support skipping some arguments when generating cache keys. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 07fa1cdd4c..d9a6be43f7 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,7 +27,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -82,28 +82,25 @@ class RelationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - pagination_chunk = await self.store.get_relations_for_event( - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) + + pagination_chunk = await self.store.get_relations_for_event( + event_id=parent_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] @@ -193,27 +190,23 @@ class RelationAggregationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_aggregation_groups_for_event( + event_id=parent_id, + room_id=room_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) return 200, await pagination_chunk.to_dict(self.store) @@ -295,6 +288,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index abd54c7dc7..d6a2df1afe 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -191,6 +191,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if redacts: self._invalidate_get_event_cache(redacts) + # Caches which might leak edits must be invalidated for the event being + # redacted. + self.get_relations_for_event.invalidate((redacts,)) + self.get_applicable_edit.invalidate((redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1dc83aa5e3..1a322882bf 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1619,9 +1619,12 @@ class PersistEventsStore: txn.call_after(prefill) - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: + # Invalidate the caches for the redacted event, note that these caches + # are also cleared as part of event replication in _invalidate_caches_for_event. txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) + txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) self.db_pool.simple_upsert_txn( txn, @@ -1812,9 +1815,7 @@ class PersistEventsStore: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) if rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (parent_id, event.room_id) - ) + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 36aa1092f6..be1500092b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore): self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore): List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ + # We don't use `event_id`, it's there so that we can cache based on + # it. The `event_id` must match the `event.event_id`. + assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event_id, room_id] + where_args: List[Union[str, int]] = [event.event_id, room_id] + is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: where_clause.append("relation_type = ?") @@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, topological_ordering, stream_ordering + SELECT event_id, relation_type, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore): last_stream_id = None events = [] for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] + # Do not include edits for redacted events as they leak event + # content. + if not is_redacted or row[1] != RelationTypes.REPLACE: + events.append({"event_id": row[0]}) + last_topo_id = row[2] + last_stream_id = row[3] # If there are more events, generate the next pagination key. next_token = None @@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore): ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self)) @@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ - # The already processed event IDs. Tracked separately from the result - # since the result omits events which do not have bundled aggregations. - seen_event_ids = set() - - # State events and redacted events do not get bundled aggregations. - events = [ - event - for event in events - if not event.is_state() and not event.internal_metadata.is_redacted() - ] + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} # Fetch other relations per event. - for event in events: - # De-duplicate events by ID to handle the same event requested multiple - # times. The caches that _get_bundled_aggregation_for_event use should - # capture this, but best to reduce work. - if event.event_id in seen_event_ids: - continue - seen_event_ids.add(event.event_id) - + for event in events_by_id.values(): event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result - # Fetch any edits. - edits = await self._get_applicable_edits(seen_event_ids) + # Fetch any edits (but not for redacted events). + edits = await self._get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. if self._msc3440_enabled: - summaries = await self._get_thread_summaries(seen_event_ids) + summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. participated = await self._get_threads_participated( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a40a5de399..f9ae6e663f 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1475,12 +1475,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(relations, {}) def test_redact_parent_annotation(self) -> None: - """Test that annotations of an event are redacted when the original event + """Test that annotations of an event are viewable when the original event is redacted. """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids, relations = self._make_relation_requests() @@ -1494,11 +1495,45 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # Redact the original event. self._redact(self.parent_id) - # The relations are not returned. + # The relations are returned. event_ids, relations = self._make_relation_requests() - self.assertEqual(event_ids, []) - self.assertEqual(relations, {}) + self.assertEquals(event_ids, [related_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + ) # There's nothing to aggregate. chunk = self._get_aggregations() - self.assertEqual(chunk, []) + self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_parent_thread(self) -> None: + """ + Test that thread replies are still available when the root event is redacted. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] + + # Redact one of the reactions. + self._redact(self.parent_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(len(event_ids), 1) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + related_event_id, + ) -- cgit 1.5.1 From ea27528b5d177dcfc5a4e38b463baeace916dc8e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 10:36:13 -0500 Subject: Support stable identifiers for MSC3440: Threading (#12151) The unstable identifiers are still supported if the experimental configuration flag is enabled. The unstable identifiers will be removed in a future release. --- changelog.d/12151.feature | 1 + synapse/api/constants.py | 4 +- synapse/api/filtering.py | 23 ++++----- synapse/events/utils.py | 9 +++- synapse/handlers/message.py | 5 +- synapse/rest/client/versions.py | 1 + synapse/server.py | 2 +- synapse/storage/databases/main/events.py | 5 +- synapse/storage/databases/main/relations.py | 77 ++++++++++++++++++----------- synapse/storage/databases/main/stream.py | 18 ++++--- tests/rest/client/test_relations.py | 7 +-- tests/rest/client/test_rooms.py | 18 +++---- tests/storage/test_stream.py | 20 ++++---- 13 files changed, 109 insertions(+), 81 deletions(-) create mode 100644 changelog.d/12151.feature (limited to 'tests/rest') diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature new file mode 100644 index 0000000000..18432b2da9 --- /dev/null +++ b/changelog.d/12151.feature @@ -0,0 +1 @@ +Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 36ace7c613..b0c08a074d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -178,7 +178,9 @@ class RelationTypes: ANNOTATION: Final = "m.annotation" REPLACE: Final = "m.replace" REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + THREAD: Final = "m.thread" + # TODO Remove this in Synapse >= v1.57.0. + UNSTABLE_THREAD: Final = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cb532d7238..27e97d6f37 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -88,7 +88,9 @@ ROOM_EVENT_FILTER_SCHEMA = { "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, # MSC3440, filtering by event relations. + "related_by_senders": {"type": "array", "items": {"type": "string"}}, "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "related_by_rel_types": {"type": "array", "items": {"type": "string"}}, "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -318,19 +320,18 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - # Ideally these would be rejected at the endpoint if they were provided - # and not supported, but that would involve modifying the JSON schema - # based on the homeserver configuration. + self.related_by_senders = self.filter_json.get("related_by_senders", None) + self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + + # Fallback to the unstable prefix if the stable version is not given. if hs.config.experimental.msc3440_enabled: - self.relation_senders = self.filter_json.get( + self.related_by_senders = self.related_by_senders or self.filter_json.get( "io.element.relation_senders", None ) - self.relation_types = self.filter_json.get( - "io.element.relation_types", None + self.related_by_rel_types = ( + self.related_by_rel_types + or self.filter_json.get("io.element.relation_types", None) ) - else: - self.relation_senders = None - self.relation_types = None def filters_all_types(self) -> bool: return "*" in self.not_types @@ -461,7 +462,7 @@ class Filter: event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( - event_ids, self.relation_senders, self.relation_types + event_ids, self.related_by_senders, self.related_by_rel_types ) ) @@ -474,7 +475,7 @@ class Filter: async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: result = [event for event in events if self._check(event)] - if self.relation_senders or self.relation_types: + if self.related_by_senders or self.related_by_rel_types: return await self._check_event_relations(result) return result diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ee34cb46e4..b2a237c1e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,6 +38,7 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.storage.databases.main.relations import BundledAggregations @@ -395,6 +396,9 @@ class EventClientSerializer: clients. """ + def __init__(self, hs: "HomeServer"): + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -515,11 +519,14 @@ class EventClientSerializer: thread.latest_event, serialized_latest_event, thread.latest_edit ) - serialized_aggregations[RelationTypes.THREAD] = { + thread_summary = { "latest_event": serialized_latest_event, "count": thread.count, "current_user_participated": thread.current_user_participated, } + serialized_aggregations[RelationTypes.THREAD] = thread_summary + if self._msc3440_enabled: + serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary # Include the bundled aggregations in the event. if serialized_aggregations: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0799ec9a84..f9544fe7fb 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1079,7 +1079,10 @@ class EventCreationHandler: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: + elif ( + relation_type == RelationTypes.THREAD + or relation_type == RelationTypes.UNSTABLE_THREAD + ): if await self.store.event_includes_relation(relates_to): raise SynapseError( 400, "Cannot start threads from an event with a relation" diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2e5d0e4e22..9a65aa4843 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,6 +101,7 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440": self.config.experimental.msc3440_enabled, + "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 1270abb5a3..7741ff29dc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -754,7 +754,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1a322882bf..1f60aef180 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1814,7 +1814,10 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - if rel_type == RelationTypes.THREAD: + if ( + rel_type == RelationTypes.THREAD + or rel_type == RelationTypes.UNSTABLE_THREAD + ): txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be1500092b..c4869d64e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -508,7 +508,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC """ else: @@ -523,16 +523,22 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY child.topological_ordering DESC, child.stream_ordering DESC """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) latest_event_ids = {} for parent_event_id, child_event_id in txn: # Only consider the latest threaded reply (by topological ordering). @@ -552,7 +558,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s GROUP BY parent.event_id """ @@ -561,9 +567,15 @@ class RelationsWorkerStore(SQLBaseStore): clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", latest_event_ids.keys() ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) return counts, latest_event_ids @@ -626,16 +638,24 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s AND child.sender = ? """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.extend((RelationTypes.THREAD, user_id)) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + args.append(user_id) + + txn.execute(sql % (clause, relations_clause), args) return {row[0] for row in txn.fetchall()} participated_threads = await self.db_pool.runInteraction( @@ -834,26 +854,23 @@ class RelationsWorkerStore(SQLBaseStore): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. - if self._msc3440_enabled: - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - summaries.keys(), user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + summaries = await self._get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated(summaries.keys(), user_id) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a898f847e7..39e1efe373 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args.extend(event_filter.labels) # Filter on relation_senders / relation types from the joined tables. - if event_filter.relation_senders: + if event_filter.related_by_senders: clauses.append( "(%s)" % " OR ".join( - "related_event.sender = ?" for _ in event_filter.relation_senders + "related_event.sender = ?" for _ in event_filter.related_by_senders ) ) - args.extend(event_filter.relation_senders) + args.extend(event_filter.related_by_senders) - if event_filter.relation_types: + if event_filter.related_by_rel_types: clauses.append( "(%s)" - % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + % " OR ".join( + "relation_type = ?" for _ in event_filter.related_by_rel_types + ) ) - args.extend(event_filter.relation_types) + args.extend(event_filter.related_by_rel_types) return " AND ".join(clauses), args @@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and ( - event_filter.relation_senders or event_filter.relation_types + event_filter.related_by_senders or event_filter.related_by_rel_types ): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). @@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ - if event_filter.relation_senders: + if event_filter.related_by_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f9ae6e663f..0cbe6c0cf7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -547,9 +547,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -758,7 +756,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -1065,7 +1062,6 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1383,7 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): chunk = self._get_aggregations() self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330..3a9617d6da 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf33054..eaa0d7d749 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) -- cgit 1.5.1 From 483f2aa2eca98500046847364ede04b034530aac Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 11 Mar 2022 10:33:49 +0000 Subject: Retention test: avoid relying on state at purged events (#12202) This test was relying on poking events which weren't in the database into filter_events_for_client. --- changelog.d/12202.misc | 1 + tests/rest/client/test_retention.py | 29 +++++++++++++++++------------ 2 files changed, 18 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12202.misc (limited to 'tests/rest') diff --git a/changelog.d/12202.misc b/changelog.d/12202.misc new file mode 100644 index 0000000000..9f333e718a --- /dev/null +++ b/changelog.d/12202.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index f3bf8d0934..7b8fe6d025 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -24,6 +24,7 @@ from synapse.util import Clock from synapse.visibility import filter_events_for_client from tests import unittest +from tests.unittest import override_config one_hour_ms = 3600000 one_day_ms = one_hour_ms * 24 @@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["retention"] = { + + # merge this default retention config with anything that was specified in + # @override_config + retention_config = { "enabled": True, "default_policy": { "min_lifetime": one_day_ms, @@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase): "allowed_lifetime_min": one_day_ms, "allowed_lifetime_max": one_day_ms * 3, } + retention_config.update(config.get("retention", {})) + config["retention"] = retention_config self.hs = self.setup_test_homeserver(config=config) @@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 2) + @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}}) def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out - outdated events + outdated events, even if the purge job hasn't got to them yet. + + We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) - events = [] # Send a first event, which should be filtered out at the end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) - - # Get the event from the store so that we end up with a FrozenEvent that we can - # give to filter_events_for_client. We need to do this now because the event won't - # be in the database anymore after it has expired. - events.append(self.get_success(store.get_event(resp.get("event_id")))) + first_event_id = resp.get("event_id") # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send another event, which shouldn't get filtered out. resp = self.helper.send(room_id=room_id, body="2", tok=self.token) - valid_event_id = resp.get("event_id") - events.append(self.get_success(store.get_event(valid_event_id))) - # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) - # Run filter_events_for_client with our list of FrozenEvents. + # Fetch the events, and run filter_events_for_client on them + events = self.get_success( + store.get_events_as_list([first_event_id, valid_event_id]) + ) + self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( filter_events_for_client(storage, self.user_id, events) ) -- cgit 1.5.1 From 32c828d0f760492711a98b11376e229d795fd1b3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 11 Mar 2022 13:42:22 +0100 Subject: Add type hints to `tests/rest`. (#12208) Co-authored-by: Patrick Cloke --- changelog.d/12208.misc | 1 + mypy.ini | 1 - tests/rest/client/test_transactions.py | 19 +++++- tests/rest/media/v1/test_media_storage.py | 110 ++++++++++++++++++------------ tests/rest/media/v1/test_url_preview.py | 83 +++++++++++----------- 5 files changed, 129 insertions(+), 85 deletions(-) create mode 100644 changelog.d/12208.misc (limited to 'tests/rest') diff --git a/changelog.d/12208.misc b/changelog.d/12208.misc new file mode 100644 index 0000000000..c5b6356799 --- /dev/null +++ b/changelog.d/12208.misc @@ -0,0 +1 @@ +Add type hints to tests files. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index c8390ddba9..f9c39fcaae 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,7 +90,6 @@ exclude = (?x) |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py |tests/rest/media/v1/test_media_storage.py - |tests/rest/media/v1/test_url_preview.py |tests/scripts/test_new_matrix_user.py |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 3b5747cb12..8d8251b2ac 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,3 +1,18 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus from unittest.mock import Mock, call from twisted.internet import defer, reactor @@ -11,14 +26,14 @@ from tests.utils import MockClock class HttpTransactionCacheTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.hs = Mock() self.hs.get_clock = Mock(return_value=self.clock) self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (200, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") self.mock_key = "foo" @defer.inlineCallbacks diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index cba9be17c4..7204b2dfe0 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -16,7 +16,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Optional +from typing import Any, BinaryIO, Dict, List, Optional, Union from unittest.mock import Mock from urllib import parse @@ -26,18 +26,24 @@ from PIL import Image as Image from twisted.internet import defer from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable +from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend +from synapse.server import HomeServer +from synapse.types import RoomAlias +from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request +from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import SMALL_PNG from tests.utils import default_config @@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): needs_threadpool = True - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") self.addCleanup(shutil.rmtree, self.test_dir) @@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): hs, self.primary_base_path, self.filepaths, storage_providers ) - def test_ensure_media_is_in_local_cache(self): + def test_ensure_media_is_in_local_cache(self) -> None: media_id = "some_media_id" test_body = "Test\n" @@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertEqual(test_body, body) -@attr.s(slots=True, frozen=True) +@attr.s(auto_attribs=True, slots=True, frozen=True) class _TestImage: """An image for testing thumbnailing with the expected results @@ -121,18 +127,18 @@ class _TestImage: a 404 is expected. """ - data = attr.ib(type=bytes) - content_type = attr.ib(type=bytes) - extension = attr.ib(type=bytes) - expected_cropped = attr.ib(type=Optional[bytes], default=None) - expected_scaled = attr.ib(type=Optional[bytes], default=None) - expected_found = attr.ib(default=True, type=bool) + data: bytes + content_type: bytes + extension: bytes + expected_cropped: Optional[bytes] = None + expected_scaled: Optional[bytes] = None + expected_found: bool = True @parameterized_class( ("test_image",), [ - # smoll png + # small png ( _TestImage( SMALL_PNG, @@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + def get_file( + destination: str, + path: str, + output_stream: BinaryIO, + args: Optional[Dict[str, Union[str, List[str]]]] = None, + max_size: Optional[int] = None, + ) -> Deferred: """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] @@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.media_id = "example.com/12345" - def _req(self, content_disposition, include_content_type=True): - + def _req( + self, content_disposition: Optional[bytes], include_content_type: bool = True + ) -> FakeChannel: channel = make_request( self.reactor, FakeSite(self.download_resource, self.reactor), @@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): return channel - def test_handle_missing_content_type(self): + def test_handle_missing_content_type(self) -> None: channel = self._req( b"inline; filename=out" + self.test_image.extension, include_content_type=False, @@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] ) - def test_disposition_filename_ascii(self): + def test_disposition_filename_ascii(self) -> None: """ If the filename is filename= then Synapse will decode it as an ASCII string, and use filename= in the response. @@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"inline; filename=out" + self.test_image.extension], ) - def test_disposition_filenamestar_utf8escaped(self): + def test_disposition_filenamestar_utf8escaped(self) -> None: """ If the filename is filename=*utf8'' then Synapse will correctly decode it as the UTF-8 string, and use filename* in the @@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"inline; filename*=utf-8''" + filename + self.test_image.extension], ) - def test_disposition_none(self): + def test_disposition_none(self) -> None: """ If there is no filename, one isn't passed on in the Content-Disposition of the request. @@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) - def test_thumbnail_crop(self): + def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) - def test_thumbnail_scale(self): + def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) - def test_invalid_type(self): + def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" self._test_thumbnail("invalid", None, False) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} ) - def test_no_thumbnail_crop(self): + def test_no_thumbnail_crop(self) -> None: """ Override the config to generate only scaled thumbnails, but request a cropped one. """ @@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase): @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} ) - def test_no_thumbnail_scale(self): + def test_no_thumbnail_scale(self) -> None: """ Override the config to generate only cropped thumbnails, but request a scaled one. """ self._test_thumbnail("scale", None, False) - def test_thumbnail_repeated_thumbnail(self): + def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ @@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.result["body"], ) - def _test_thumbnail(self, method, expected_body, expected_found): + def _test_thumbnail( + self, method: str, expected_body: Optional[bytes], expected_found: bool + ) -> None: params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) - def test_same_quality(self, method, desired_size): + def test_same_quality(self, method: str, desired_size: int) -> None: """Test that choosing between thumbnails with the same quality rating succeeds. We are not particular about which thumbnail is chosen.""" @@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) ) - def test_x_robots_tag_header(self): + def test_x_robots_tag_header(self) -> None: """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers to not index, archive, or follow links in media. @@ -540,29 +555,38 @@ class TestSpamChecker: `evil`. """ - def __init__(self, config, api): + def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: self.config = config self.api = api - def parse_config(config): + def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config - async def check_event_for_spam(self, foo): + async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]: return False # allow all events - async def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite( + self, + inviter_userid: str, + invitee_userid: str, + room_id: str, + ) -> bool: return True # allow all invites - async def user_may_create_room(self, userid): + async def user_may_create_room(self, userid: str) -> bool: return True # allow all room creations - async def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias( + self, userid: str, room_alias: RoomAlias + ) -> bool: return True # allow all room aliases - async def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: return True # allow publishing of all rooms - async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> bool: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) @@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") @@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): load_legacy_spam_checkers(hs) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = default_config("test") config.update( @@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): return config - def test_upload_innocent(self): + def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" self.helper.upload_media( self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 ) - def test_upload_ban(self): + def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should get rejected by the spam checker. """ diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index da2c533260..5148c39874 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -16,16 +16,21 @@ import base64 import json import os import re +from typing import Any, Dict, Optional, Sequence, Tuple, Type from urllib.parse import urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError -from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.internet.interfaces import IAddress, IResolutionReceiver +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from synapse.config.oembed import OEmbedEndpointConfig +from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest @@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"" ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["url_preview_enabled"] = True @@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] - self.lookups = {} + self.lookups: Dict[str, Any] = {} class Resolver: def resolveHostName( _self, - resolutionReceiver, - hostName, - portNumber=0, - addressTypes=None, - transportSemantics="TCP", - ): + resolutionReceiver: IResolutionReceiver, + hostName: str, + portNumber: int = 0, + addressTypes: Optional[Sequence[Type[IAddress]]] = None, + transportSemantics: str = "TCP", + ) -> IResolutionReceiver: resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) @@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): resolutionReceiver.resolutionComplete() return resolutionReceiver - self.reactor.nameResolver = Resolver() + self.reactor.nameResolver = Resolver() # type: ignore[assignment] - def create_test_resource(self): + def create_test_resource(self) -> MediaRepositoryResource: return self.hs.get_media_repository_resource() def _assert_small_png(self, json_body: JsonDict) -> None: @@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(json_body["og:image:type"], "image/png") self.assertEqual(json_body["matrix:image:size"], 67) - def test_cache_returns_correct_type(self): + def test_cache_returns_correct_type(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] channel = self.make_request( @@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_non_ascii_preview_httpequiv(self): + def test_non_ascii_preview_httpequiv(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - def test_video_rejected(self): + def test_video_rejected(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = b"anything" @@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_audio_rejected(self): + def test_audio_rejected(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = b"anything" @@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_non_ascii_preview_content_type(self): + def test_non_ascii_preview_content_type(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - def test_overlong_title(self): + def test_overlong_title(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # We should only see the `og:description` field, as `title` is too long and should be stripped out self.assertCountEqual(["og:description"], res.keys()) - def test_ipaddr(self): + def test_ipaddr(self) -> None: """ IP addresses can be previewed directly. """ @@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_specific(self): + def test_blacklisted_ip_specific(self) -> None: """ Blacklisted IP addresses, found via DNS, are not spidered. """ @@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_range(self): + def test_blacklisted_ip_range(self) -> None: """ Blacklisted IP ranges, IPs found over DNS, are not spidered. """ @@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_specific_direct(self): + def test_blacklisted_ip_specific_direct(self) -> None: """ Blacklisted IP addresses, accessed directly, are not spidered. """ @@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 403) - def test_blacklisted_ip_range_direct(self): + def test_blacklisted_ip_range_direct(self) -> None: """ Blacklisted IP ranges, accessed directly, are not spidered. """ @@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_range_whitelisted_ip(self): + def test_blacklisted_ip_range_whitelisted_ip(self) -> None: """ Blacklisted but then subsequently whitelisted IP addresses can be spidered. @@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_with_external_ip(self): + def test_blacklisted_ip_with_external_ip(self) -> None: """ If a hostname resolves a blacklisted IP, even if there's a non-blacklisted one, it will be rejected. @@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ipv6_specific(self): + def test_blacklisted_ipv6_specific(self) -> None: """ Blacklisted IP addresses, found via DNS, are not spidered. """ @@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ipv6_range(self): + def test_blacklisted_ipv6_range(self) -> None: """ Blacklisted IP ranges, IPs found over DNS, are not spidered. """ @@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_OPTIONS(self): + def test_OPTIONS(self) -> None: """ OPTIONS returns the OPTIONS. """ @@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) - def test_accept_language_config_option(self): + def test_accept_language_config_option(self) -> None: """ Accept-Language header is sent to the remote server """ @@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) - def test_data_url(self): + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. """ @@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 500) - def test_inline_data_url(self): + def test_inline_data_url(self) -> None: """ An inline image (as a data URL) should be parsed properly. """ @@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self._assert_small_png(channel.json_body) - def test_oembed_photo(self): + def test_oembed_photo(self) -> None: """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") self._assert_small_png(body) - def test_oembed_rich(self): + def test_oembed_rich(self) -> None: """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_oembed_format(self): + def test_oembed_format(self) -> None: """Test an oEmbed endpoint which requires the format in the URL.""" self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")] @@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_oembed_autodiscovery(self): + def test_oembed_autodiscovery(self) -> None: """ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. 1. Request a preview of a URL which is not known to the oEmbed code. @@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self._assert_small_png(body) - def _download_image(self): + def _download_image(self) -> Tuple[str, str]: """Downloads an image into the URL cache. Returns: A (host, media_id) tuple representing the MXC URI of the image. @@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertIsNone(_port) return host, media_id - def test_storage_providers_exclude_files(self): + def test_storage_providers_exclude_files(self) -> None: """Test that files are not stored in or fetched from storage providers.""" host, media_id = self._download_image() @@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): "URL cache file was unexpectedly retrieved from a storage provider", ) - def test_storage_providers_exclude_thumbnails(self): + def test_storage_providers_exclude_thumbnails(self) -> None: """Test that thumbnails are not stored in or fetched from storage providers.""" host, media_id = self._download_image() @@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): "URL cache thumbnail was unexpectedly retrieved from a storage provider", ) - def test_cache_expiry(self): + def test_cache_expiry(self) -> None: """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" self.preview_url.clock = MockClock() -- cgit 1.5.1 From ef3619e61d84493d98470eb2a69131d15eb1166b Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 11 Mar 2022 10:46:45 -0800 Subject: Add config settings for background update parameters (#11980) --- changelog.d/11980.misc | 1 + docs/sample_config.yaml | 32 ++++ synapse/config/_base.pyi | 2 + synapse/config/background_updates.py | 68 ++++++++ synapse/config/homeserver.py | 2 + synapse/storage/background_updates.py | 39 +++-- tests/config/test_background_update.py | 58 +++++++ tests/rest/admin/test_background_updates.py | 9 +- tests/storage/test_background_update.py | 253 ++++++++++++++++++++++++++-- 9 files changed, 430 insertions(+), 34 deletions(-) create mode 100644 changelog.d/11980.misc create mode 100644 synapse/config/background_updates.py create mode 100644 tests/config/test_background_update.py (limited to 'tests/rest') diff --git a/changelog.d/11980.misc b/changelog.d/11980.misc new file mode 100644 index 0000000000..36e992e645 --- /dev/null +++ b/changelog.d/11980.misc @@ -0,0 +1 @@ +Add config settings for background update parameters. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d634fd8ff5..36c6c56e58 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2735,3 +2735,35 @@ redis: # Optional password if configured on the Redis instance # #password: + + +## Background Updates ## + +# Background updates are database updates that are run in the background in batches. +# The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to +# sleep can all be configured. This is helpful to speed up or slow down the updates. +# +background_updates: + # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set + # a time to change the default. + # + #background_update_duration_ms: 500 + + # Whether to sleep between updates. Defaults to True. Uncomment to change the default. + # + #sleep_enabled: false + + # If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment + # and set a duration to change the default. + # + #sleep_duration_ms: 300 + + # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and + # set a size to change the default. + # + #min_batch_size: 10 + + # The batch size to use for the first iteration of a new background update. The default is 100. + # Uncomment and set a size to change the default. + # + #default_batch_size: 50 diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 1eb5f5a68c..363d8b4554 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -19,6 +19,7 @@ from synapse.config import ( api, appservice, auth, + background_updates, cache, captcha, cas, @@ -113,6 +114,7 @@ class RootConfig: caches: cache.CacheConfig federation: federation.FederationConfig retention: retention.RetentionConfig + background_updates: background_updates.BackgroundUpdateConfig config_classes: List[Type["Config"]] = ... def __init__(self) -> None: ... diff --git a/synapse/config/background_updates.py b/synapse/config/background_updates.py new file mode 100644 index 0000000000..f6cdeacc4b --- /dev/null +++ b/synapse/config/background_updates.py @@ -0,0 +1,68 @@ +# Copyright 2022 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class BackgroundUpdateConfig(Config): + section = "background_updates" + + def generate_config_section(self, **kwargs) -> str: + return """\ + ## Background Updates ## + + # Background updates are database updates that are run in the background in batches. + # The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to + # sleep can all be configured. This is helpful to speed up or slow down the updates. + # + background_updates: + # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set + # a time to change the default. + # + #background_update_duration_ms: 500 + + # Whether to sleep between updates. Defaults to True. Uncomment to change the default. + # + #sleep_enabled: false + + # If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment + # and set a duration to change the default. + # + #sleep_duration_ms: 300 + + # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and + # set a size to change the default. + # + #min_batch_size: 10 + + # The batch size to use for the first iteration of a new background update. The default is 100. + # Uncomment and set a size to change the default. + # + #default_batch_size: 50 + """ + + def read_config(self, config, **kwargs) -> None: + bg_update_config = config.get("background_updates") or {} + + self.update_duration_ms = bg_update_config.get( + "background_update_duration_ms", 100 + ) + + self.sleep_enabled = bg_update_config.get("sleep_enabled", True) + + self.sleep_duration_ms = bg_update_config.get("sleep_duration_ms", 1000) + + self.min_batch_size = bg_update_config.get("min_batch_size", 1) + + self.default_batch_size = bg_update_config.get("default_batch_size", 100) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 001605c265..a4ec706908 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -16,6 +16,7 @@ from .account_validity import AccountValidityConfig from .api import ApiConfig from .appservice import AppServiceConfig from .auth import AuthConfig +from .background_updates import BackgroundUpdateConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig @@ -99,4 +100,5 @@ class HomeServerConfig(RootConfig): WorkerConfig, RedisConfig, ExperimentalConfig, + BackgroundUpdateConfig, ] diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 4acc2c997d..08c6eabc6d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -60,18 +60,19 @@ class _BackgroundUpdateHandler: class _BackgroundUpdateContextManager: - BACKGROUND_UPDATE_INTERVAL_MS = 1000 - BACKGROUND_UPDATE_DURATION_MS = 100 - - def __init__(self, sleep: bool, clock: Clock): + def __init__( + self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration: int + ): self._sleep = sleep self._clock = clock + self._sleep_duration_ms = sleep_duration_ms + self._update_duration_ms = update_duration async def __aenter__(self) -> int: if self._sleep: - await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000) + await self._clock.sleep(self._sleep_duration_ms / 1000) - return self.BACKGROUND_UPDATE_DURATION_MS + return self._update_duration_ms async def __aexit__(self, *exc) -> None: pass @@ -133,9 +134,6 @@ class BackgroundUpdater: process and autotuning the batch size. """ - MINIMUM_BACKGROUND_BATCH_SIZE = 1 - DEFAULT_BACKGROUND_BATCH_SIZE = 100 - def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database @@ -160,6 +158,14 @@ class BackgroundUpdater: # enable/disable background updates via the admin API. self.enabled = True + self.minimum_background_batch_size = hs.config.background_updates.min_batch_size + self.default_background_batch_size = ( + hs.config.background_updates.default_batch_size + ) + self.update_duration_ms = hs.config.background_updates.update_duration_ms + self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms + self.sleep_enabled = hs.config.background_updates.sleep_enabled + def register_update_controller_callbacks( self, on_update: ON_UPDATE_CALLBACK, @@ -216,7 +222,9 @@ class BackgroundUpdater: if self._on_update_callback is not None: return self._on_update_callback(update_name, database_name, oneshot) - return _BackgroundUpdateContextManager(sleep, self._clock) + return _BackgroundUpdateContextManager( + sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms + ) async def _default_batch_size(self, update_name: str, database_name: str) -> int: """The batch size to use for the first iteration of a new background @@ -225,7 +233,7 @@ class BackgroundUpdater: if self._default_batch_size_callback is not None: return await self._default_batch_size_callback(update_name, database_name) - return self.DEFAULT_BACKGROUND_BATCH_SIZE + return self.default_background_batch_size async def _min_batch_size(self, update_name: str, database_name: str) -> int: """A lower bound on the batch size of a new background update. @@ -235,7 +243,7 @@ class BackgroundUpdater: if self._min_batch_size_callback is not None: return await self._min_batch_size_callback(update_name, database_name) - return self.MINIMUM_BACKGROUND_BATCH_SIZE + return self.minimum_background_batch_size def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: """Returns the current background update, if any.""" @@ -254,9 +262,12 @@ class BackgroundUpdater: if self.enabled: # if we start a new background update, not all updates are done. self._all_done = False - run_as_background_process("background_updates", self.run_background_updates) + sleep = self.sleep_enabled + run_as_background_process( + "background_updates", self.run_background_updates, sleep + ) - async def run_background_updates(self, sleep: bool = True) -> None: + async def run_background_updates(self, sleep: bool) -> None: if self._running or not self.enabled: return diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py new file mode 100644 index 0000000000..0c32c1ca29 --- /dev/null +++ b/tests/config/test_background_update.py @@ -0,0 +1,58 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml + +from synapse.storage.background_updates import BackgroundUpdater + +from tests.unittest import HomeserverTestCase, override_config + + +class BackgroundUpdateConfigTestCase(HomeserverTestCase): + # Tests that the default values in the config are correctly loaded. Note that the default + # values are loaded when the corresponding config options are commented out, which is why there isn't + # a config specified here. + def test_default_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 1) + self.assertEqual(background_updater.default_background_batch_size, 100) + self.assertEqual(background_updater.sleep_enabled, True) + self.assertEqual(background_updater.sleep_duration_ms, 1000) + self.assertEqual(background_updater.update_duration_ms, 100) + + # Tests that non-default values for the config options are properly picked up and passed on. + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 1000 + sleep_enabled: false + sleep_duration_ms: 600 + min_batch_size: 5 + default_batch_size: 50 + """ + ) + ) + def test_custom_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 5) + self.assertEqual(background_updater.default_background_batch_size, 50) + self.assertEqual(background_updater.sleep_enabled, False) + self.assertEqual(background_updater.sleep_duration_ms, 600) + self.assertEqual(background_updater.update_duration_ms, 1000) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index becec84524..6cf56b1e35 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") + self.updater = BackgroundUpdater(hs, self.store.db_pool) @parameterized.expand( [ @@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): """Test the status API works with a background update.""" # Create a new background update - self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() + self.reactor.pump([1.0, 1.0, 1.0]) channel = self.make_request( @@ -158,7 +159,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -213,7 +214,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -242,7 +243,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fdf54ea31..5cf18b690e 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -14,12 +14,15 @@ from unittest.mock import Mock +import yaml + from twisted.internet.defer import Deferred, ensureDeferred from synapse.storage.background_updates import BackgroundUpdater from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config class BackgroundUpdateTestCase(unittest.HomeserverTestCase): @@ -34,6 +37,19 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.updates.register_background_update_handler( "test_update", self.update_handler ) + self.store = self.hs.get_datastores().main + + async def update(self, progress, count): + duration_ms = 10 + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count def test_do_background_update(self): # the time we claim it takes to update one item when running the update @@ -42,27 +58,14 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastores().main self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) ) - # first step: make a bit of progress - async def update(progress, count): - await self.clock.sleep((count * duration_ms) / 1000) - progress = {"my_key": progress["my_key"] + 1} - await store.db_pool.runInteraction( - "update_progress", - self.updates._background_update_progress_txn, - "test_update", - progress, - ) - return count - - self.update_handler.side_effect = update + self.update_handler.side_effect = self.update self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), @@ -72,7 +75,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.default_background_batch_size ) # second step: complete the update @@ -99,6 +102,224 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(result) self.assertFalse(self.update_handler.called) + @override_config( + yaml.safe_load( + """ + background_updates: + default_batch_size: 20 + """ + ) + ) + def test_background_update_default_batch_set_by_config(self): + """ + Test that the background update is run with the default_batch_size set by the config + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.01, + ) + self.assertFalse(res) + + # on the first call, we should get run with the default background update size specified in the config + self.update_handler.assert_called_once_with({"my_key": 1}, 20) + + def test_background_update_default_sleep_behavior(self): + """ + Test default background update behavior, which is to sleep + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor less than the default sleep duration (1000ms) + self.reactor.pump([0.5]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past default sleep duration + self.reactor.pump([1]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_duration_ms: 500 + """ + ) + ) + def test_background_update_sleep_set_in_config(self): + """ + Test that changing the sleep time in the config changes how long it sleeps + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor less than the configured sleep duration (500ms) + self.reactor.pump([0.45]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past config sleep duration but less than default duration + self.reactor.pump([0.75]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_enabled: false + """ + ) + ) + def test_disabling_background_update_sleep(self): + """ + Test that disabling sleep in the config results in bg update not sleeping + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor very little + self.reactor.pump([0.025]) + # check that an update has run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 500 + """ + ) + ) + def test_background_update_duration_set_in_config(self): + """ + Test that the desired duration set in the config is used in determining batch size + """ + # Duration of one background update item + duration_ms = 10 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.02, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with 500ms as the + # desired duration + async def update(progress, count): + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, + 500 / duration_ms, + places=0, + ) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + + @override_config( + yaml.safe_load( + """ + background_updates: + min_batch_size: 5 + """ + ) + ) + def test_background_update_min_batch_set_in_config(self): + """ + Test that the minimum batch size set in the config is used + """ + # a very long-running individual update + duration_ms = 50 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + # Run the update with the long-running update item + async def update(progress, count): + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count + + self.update_handler.side_effect = update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=1, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with minimum batch size + # as the first items took a very long time + async def update(progress, count): + self.assertEqual(progress, {"my_key": 2}) + self.assertEqual(count, 5) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): -- cgit 1.5.1 From 4587b35929d22731644a11120a9e7d6a9c3bc304 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 07:21:36 -0400 Subject: Clean-up logic for rebasing URLs during URL preview. (#12219) By using urljoin from the standard library and reducing the number of places URLs are rebased. --- changelog.d/12219.misc | 1 + synapse/rest/media/v1/preview_html.py | 39 +------------------ synapse/rest/media/v1/preview_url_resource.py | 23 ++++++------ tests/rest/media/v1/test_html_preview.py | 54 ++++++--------------------- 4 files changed, 26 insertions(+), 91 deletions(-) create mode 100644 changelog.d/12219.misc (limited to 'tests/rest') diff --git a/changelog.d/12219.misc b/changelog.d/12219.misc new file mode 100644 index 0000000000..6079414092 --- /dev/null +++ b/changelog.d/12219.misc @@ -0,0 +1 @@ +Clean-up logic around rebasing URLs for URL image previews. diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 872a9e72e8..4cc9c66fbe 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -16,7 +16,6 @@ import itertools import logging import re from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union -from urllib import parse as urlparse if TYPE_CHECKING: from lxml import etree @@ -144,9 +143,7 @@ def decode_body( return etree.fromstring(body, parser) -def parse_html_to_open_graph( - tree: "etree.Element", media_uri: str -) -> Dict[str, Optional[str]]: +def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: """ Parse the HTML document into an Open Graph response. @@ -155,7 +152,6 @@ def parse_html_to_open_graph( Args: tree: The parsed HTML document. - media_url: The URI used to download the body. Returns: The Open Graph response as a dictionary. @@ -209,7 +205,7 @@ def parse_html_to_open_graph( "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" ) if meta_image: - og["og:image"] = rebase_url(meta_image[0], media_uri) + og["og:image"] = meta_image[0] else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") @@ -320,37 +316,6 @@ def _iterate_over_text( ) -def rebase_url(url: str, base: str) -> str: - """ - Resolves a potentially relative `url` against an absolute `base` URL. - - For example: - - >>> rebase_url("subpage", "https://example.com/foo/") - 'https://example.com/foo/subpage' - >>> rebase_url("sibling", "https://example.com/foo") - 'https://example.com/sibling' - >>> rebase_url("/bar", "https://example.com/foo/") - 'https://example.com/bar' - >>> rebase_url("https://alice.com/a/", "https://example.com/foo/") - 'https://alice.com/a' - """ - base_parts = urlparse.urlparse(base) - # Convert the parsed URL to a list for (potential) modification. - url_parts = list(urlparse.urlparse(url)) - # Add a scheme, if one does not exist. - if not url_parts[0]: - url_parts[0] = base_parts.scheme or "http" - # Fix up the hostname, if this is not a data URL. - if url_parts[0] != "data" and not url_parts[1]: - url_parts[1] = base_parts.netloc - # If the path does not start with a /, nest it under the base path's last - # directory. - if not url_parts[2].startswith("/"): - url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2] - return urlparse.urlunparse(url_parts) - - def summarize_paragraphs( text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 ) -> Optional[str]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 14ea88b240..d47af8ead6 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -22,7 +22,7 @@ import shutil import sys import traceback from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple -from urllib import parse as urlparse +from urllib.parse import urljoin, urlparse, urlsplit from urllib.request import urlopen import attr @@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import ( - decode_body, - parse_html_to_open_graph, - rebase_url, -) +from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred @@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource): ts = self.clock.time_msec() # XXX: we could move this into _do_preview if we wanted. - url_tuple = urlparse.urlsplit(url) + url_tuple = urlsplit(url) for entry in self.url_preview_url_blacklist: match = True for attrib in entry: @@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource): # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. - og_from_html = parse_html_to_open_graph(tree, media_info.uri) + og_from_html = parse_html_to_open_graph(tree) # Compile the Open Graph response by using the scraped # information from the HTML and overlaying any information @@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource): if "og:image" not in og or not og["og:image"]: return + # The image URL from the HTML might be relative to the previewed page, + # convert it to an URL which can be requested directly. + image_url = og["og:image"] + url_parts = urlparse(image_url) + if url_parts.scheme != "data": + image_url = urljoin(media_info.uri, image_url) + # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - image_info = await self._handle_url( - rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True - ) + image_info = await self._handle_url(image_url, user, allow_data_urls=True) if _is_media(image_info.media_type): # TODO: make sure we don't choke on white-on-transparent images diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 3fb37a2a59..62e308814d 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import ( _get_html_media_encodings, decode_body, parse_html_to_open_graph, - rebase_url, summarize_paragraphs, ) @@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase): FooSome text. """.strip() tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding(self) -> None: @@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding2(self) -> None: @@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) def test_windows_1252(self) -> None: @@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) @@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase): 'text/html; charset="invalid"', ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - -class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self) -> None: - """Relative URLs should be resolved based on the context of the base URL.""" - self.assertEqual( - rebase_url("subpage", "https://example.com/foo/"), - "https://example.com/foo/subpage", - ) - self.assertEqual( - rebase_url("sibling", "https://example.com/foo"), - "https://example.com/sibling", - ) - self.assertEqual( - rebase_url("/bar", "https://example.com/foo/"), - "https://example.com/bar", - ) - - def test_absolute(self) -> None: - """Absolute URLs should not be modified.""" - self.assertEqual( - rebase_url("https://alice.com/a/", "https://example.com/foo/"), - "https://alice.com/a/", - ) - - def test_data(self) -> None: - """Data URLs should not be modified.""" - self.assertEqual( - rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), - "data:,Hello%2C%20World%21", - ) -- cgit 1.5.1 From 1da0f79d5455b594f2aa989106a672786f5b990f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 09:20:57 -0400 Subject: Refactor relations tests (#12232) * Moves the relation pagination tests to a separate class. * Move the assertion of the response code into the `_send_relation` helper. * Moves some helpers into the base-class. --- changelog.d/12232.misc | 1 + tests/rest/client/test_relations.py | 1389 +++++++++++++++++------------------ 2 files changed, 674 insertions(+), 716 deletions(-) create mode 100644 changelog.d/12232.misc (limited to 'tests/rest') diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc new file mode 100644 index 0000000000..4a4132edff --- /dev/null +++ b/changelog.d/12232.misc @@ -0,0 +1 @@ +Refactor relations tests to improve code re-use. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 0cbe6c0cf7..3dbd1304a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -79,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content: Optional[dict] = None, access_token: Optional[str] = None, parent_id: Optional[str] = None, + expected_response_code: int = 200, ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` @@ -115,16 +116,50 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content, access_token=access_token, ) + self.assertEqual(expected_response_code, channel.code, channel.json_body) return channel + def _get_related_events(self) -> List[str]: + """ + Requests /relations on the parent ID and returns a list of event IDs. + """ + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return [ev["event_id"] for ev in channel.json_body["chunk"]] + + def _get_bundled_aggregations(self) -> JsonDict: + """ + Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists. + """ + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return channel.json_body["unsigned"].get("m.relations", {}) + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: """Tests that sending a relation works.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] channel = self.make_request( @@ -151,13 +186,13 @@ class RelationsTestCase(BaseRelationsTestCase): def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -171,13 +206,12 @@ class RelationsTestCase(BaseRelationsTestCase): desc="test_deny_invalid_event", ) ) - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" @@ -187,18 +221,20 @@ class RelationsTestCase(BaseRelationsTestCase): parent_id = res["event_id"] # Attempt to send an annotation to that event. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + parent_id=parent_id, + key="A", + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(400, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400 + ) def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" @@ -208,461 +244,160 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, "m.room.message", content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) - def test_basic_paginate_relations(self) -> None: - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - first_annotation_id = channel.json_body["event_id"] + def test_aggregation(self) -> None: + """Test that annotations get correctly aggregated.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - second_annotation_id = channel.json_body["event_id"] + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self.make_request( "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - # We expect to get back a single pagination result, which is the latest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( + self.assertEqual( + channel.json_body, { - "event_id": second_annotation_id, - "sender": self.user_id, - "type": "m.reaction", + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] }, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id ) - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) + def test_aggregation_must_be_annotation(self) -> None: + """Test that aggregations must be annotations.""" - # Request the relations again, but with a different direction. channel = self.make_request( "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations" + f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", access_token=self.user_token, ) - self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - # We expect to get back a single pagination result, which is the earliest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": first_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_bundled_aggregations(self) -> None: + """ + Test that annotations, references, and threads get correctly bundled. - def test_repeated_paginate_relations(self) -> None: - """Test that if we paginate using a limit and tokens then we get the - expected events. + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + + See test_edit for a similar test for edits. """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - expected_event_ids = [] - for idx in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] - prev_token = "" - found_event_ids: List[str] = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.THREAD, "m.room.test") - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") - if not prev_token: - break + # Ensure the fields are as expected. + self.assertCountEqual( + relations_dict.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) + # Check the values of each field. + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + relations_dict[RelationTypes.ANNOTATION], + ) - def test_pagination_from_sync_and_messages(self) -> None: - """Pagination tokens from /sync and /messages can be used to paginate /relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - # Send an event after the relation events. - self.helper.send(self.room, body="Latest event", tok=self.user_token) + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + relations_dict[RelationTypes.REFERENCE], + ) - # Request /sync, limiting it such that only the latest event is returned - # (and not the relation). - filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') + self.assertEqual( + 2, + relations_dict[RelationTypes.THREAD].get("count"), + ) + self.assertTrue( + relations_dict[RelationTypes.THREAD].get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + relations_dict[RelationTypes.THREAD].get("latest_event"), + ) + + # Request the event directly. channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - sync_prev_batch = room_timeline["prev_batch"] - self.assertIsNotNone(sync_prev_batch) - # Ensure the relation event is not in the batch returned from /sync. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in room_timeline["events"]] - ) + assert_bundle(channel.json_body) - # Request /messages, limiting it such that only the latest event is - # returned (and not the relation). + # Request the room messages. channel = self.make_request( "GET", - f"/rooms/{self.room}/messages?dir=b&limit=1", + f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - messages_end = channel.json_body["end"] - self.assertIsNotNone(messages_end) - # Ensure the relation event is not in the chunk returned from /messages. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, ) - - # Request /relations with the pagination tokens received from both the - # /sync and /messages responses above, in turn. - # - # This is a tiny bit silly since the client wouldn't know the parent ID - # from the requests above; consider the parent ID to be known from a - # previous /sync. - for from_token in (sync_prev_batch, messages_end): - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # The relation should be in the returned chunk. - self.assertIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - def test_aggregation(self) -> None: - """Test that annotations get correctly aggregated.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_must_be_annotation(self) -> None: - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations" - f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(400, channel.code, channel.json_body) - - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) @@ -693,14 +428,12 @@ class RelationsTestCase(BaseRelationsTestCase): when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -713,14 +446,12 @@ class RelationsTestCase(BaseRelationsTestCase): def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -877,8 +608,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -954,7 +683,7 @@ class RelationsTestCase(BaseRelationsTestCase): shouldn't be allowed, are correctly handled. """ - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -963,309 +692,575 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + edit_event_id = channel.json_body["event_id"] + + self._send_relation( + RelationTypes.REPLACE, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_edit_reply(self) -> None: + """Test that editing a reply works.""" + + # Create a reply to edit. + channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{reply}", + access_token=self.user_token, + ) self.assertEqual(200, channel.code, channel.json_body) + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. + self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_edit_thread(self) -> None: + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + threaded_event_id = channel.json_body["event_id"] + new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + + def test_edit_edit(self) -> None: + """Test that an edit cannot be edited.""" + new_body = {"msgtype": "m.text", "body": "Initial edit"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + }, + ) + edit_event_id = channel.json_body["event_id"] + + # Edit the edit event. + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, + }, + parent_id=edit_event_id, + ) + + # Request the original event. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + # The edit to the edit should be ignored. + self.assertEqual(channel.json_body["content"], new_body) + + # The relations information should not include the edit to the edit. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_unknown_relations(self) -> None: + """Unknown relations should be accepted.""" + channel = self._send_relation("m.relation.test", "m.room.test") + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # When bundling the unknown relation is not included. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + + # But unknown relations can be directly queried. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + + def test_background_update(self) -> None: + """Test the event_arbitrary_relations background update.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + annotation_event_id_good = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") + annotation_event_id_bad = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_event_id = channel.json_body["event_id"] + + # Clean-up the table as if the inserts did not happen during event creation. + self.get_success( + self.store.db_pool.simple_delete_many( + table="event_relations", + column="event_id", + iterable=(annotation_event_id_bad, thread_event_id), + keyvalues={}, + desc="RelationsTestCase.test_background_update", + ) + ) + + # Only the "good" annotation should be found. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", + access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good], + ) - edit_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message.WRONG_TYPE", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, - }, + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "event_arbitrary_relations", "progress_json": "{}"}, + ) ) - self.assertEqual(200, channel.code, channel.json_body) + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + self.wait_for_background_updates() + + # The "good" annotation and the thread should be found, but not the "bad" + # annotation. channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{self.parent_id}", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) + self.assertCountEqual( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good, thread_event_id], + ) - self.assertEqual(channel.json_body["content"], new_body) - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) +class RelationPaginationTestCase(BaseRelationsTestCase): + def test_basic_paginate_relations(self) -> None: + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + first_annotation_id = channel.json_body["event_id"] - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + second_annotation_id = channel.json_body["event_id"] - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, ) + self.assertEqual(200, channel.code, channel.json_body) - def test_edit_reply(self) -> None: - """Test that editing a reply works.""" - - # Create a reply to edit. - channel = self._send_relation( - RelationTypes.REFERENCE, - "m.room.message", - content={"msgtype": "m.text", "body": "A reply!"}, + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], ) - self.assertEqual(200, channel.code, channel.json_body) - reply = channel.json_body["event_id"] - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=reply, + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + # Request the relations again, but with a different direction. channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{reply}", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - # We expect to see the new body in the dict, as well as the reference - # metadata sill intact. - self.assertDictContainsSubset(new_body, channel.json_body["content"]) - self.assertDictContainsSubset( + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": "m.reference", - } + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", }, - channel.json_body["content"], + channel.json_body["chunk"][0], ) - # We expect that the edit relation appears in the unsigned relations - # section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + def test_repeated_paginate_relations(self) -> None: + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + expected_event_ids = [] + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) + expected_event_ids.append(channel.json_body["event_id"]) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) + prev_token = "" + found_event_ids: List[str] = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token - def test_edit_thread(self) -> None: - """Test that editing a thread works.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - # Create a thread and edit the last event. - channel = self._send_relation( - RelationTypes.THREAD, - "m.room.message", - content={"msgtype": "m.text", "body": "A threaded reply!"}, - ) - self.assertEqual(200, channel.code, channel.json_body) - threaded_event_id = channel.json_body["event_id"] + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=threaded_event_id, + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self) -> None: + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token ) self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] + ) - # Fetch the thread root, to get the bundled aggregation for the thread. + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{self.parent_id}", + f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) - # We expect that the edit message appears in the thread summary in the - # unsigned relations section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.THREAD, relations_dict) + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. + for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - thread_summary = relations_dict[RelationTypes.THREAD] - self.assertIn("latest_event", thread_summary) - latest_event_in_thread = thread_summary["latest_event"] - self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + def test_aggregation_pagination_groups(self) -> None: + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + + idx += 1 + idx %= len(access_tokens) - def test_edit_edit(self) -> None: - """Test that an edit cannot be edited.""" - new_body = {"msgtype": "m.text", "body": "Initial edit"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": new_body, - }, - ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] + prev_token: Optional[str] = None + found_groups: Dict[str, int] = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token - # Edit the edit event. - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "foo", - "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, - }, - parent_id=edit_event_id, - ) - self.assertEqual(200, channel.code, channel.json_body) + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - # Request the original event. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - # The edit to the edit should be ignored. - self.assertEqual(channel.json_body["content"], new_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - # The relations information should not include the edit to the edit. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) + found_groups[groups["key"]] = groups["count"] - def test_unknown_relations(self) -> None: - """Unknown relations should be accepted.""" - channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] + next_batch = channel.json_body.get("next_batch") - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch - # We expect to get back a single pagination result, which is the full - # relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, - channel.json_body["chunk"][0], - ) + if not prev_token: + break - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) + self.assertEqual(sent_groups, found_groups) - # When bundling the unknown relation is not included. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_aggregation_pagination_within_group(self) -> None: + """Test that we can paginate within an annotation group.""" - # But unknown relations can be directly queried. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) - raise AssertionError(f"Event {self.parent_id} not found in chunk") + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + expected_event_ids.append(channel.json_body["event_id"]) - def test_background_update(self) -> None: - """Test the event_arbitrary_relations background update.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) - annotation_event_id_good = channel.json_body["event_id"] + idx += 1 - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_event_id_bad = channel.json_body["event_id"] + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_event_id = channel.json_body["event_id"] + prev_token = "" + found_event_ids: List[str] = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token - # Clean-up the table as if the inserts did not happen during event creation. - self.get_success( - self.store.db_pool.simple_delete_many( - table="event_relations", - column="event_id", - iterable=(annotation_event_id_bad, thread_event_id), - keyvalues={}, - desc="RelationsTestCase.test_background_update", + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, ) - ) + self.assertEqual(200, channel.code, channel.json_body) - # Only the "good" annotation should be found. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - [ev["event_id"] for ev in channel.json_body["chunk"]], - [annotation_event_id_good], - ) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - # Insert and run the background update. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "event_arbitrary_relations", "progress_json": "{}"}, - ) - ) + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - # Ugh, have to reset this flag - self.store.db_pool.updates._all_done = False - self.wait_for_background_updates() + next_batch = channel.json_body.get("next_batch") - # The "good" annotation and the thread should be found, but not the "bad" - # annotation. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertCountEqual( - [ev["event_id"] for ev in channel.json_body["chunk"]], - [annotation_event_id_good, thread_event_id], - ) + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) class RelationRedactionTestCase(BaseRelationsTestCase): @@ -1294,46 +1289,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) - def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: - """ - Makes requests and ensures they result in a 200 response, returns a - tuple of results: - - 1. `/relations` -> Returns a list of event IDs. - 2. `/event` -> Returns the response's m.relations field (from unsigned), - if it exists. - """ - - # Request the relations of the event. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] - - # Fetch the bundled aggregations of the event. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) - - return event_ids, bundled_relations - - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1343,17 +1298,16 @@ class RelationRedactionTestCase(BaseRelationsTestCase): the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Both relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1368,7 +1322,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1391,7 +1346,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Note that the *last* event in the thread is redacted, as that gets @@ -1401,11 +1355,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 2", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] # Both relations exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( { @@ -1424,7 +1378,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( { @@ -1444,7 +1399,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): is redacted. """ # Add a relation - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", parent_id=self.parent_id, @@ -1454,10 +1409,10 @@ class RelationRedactionTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.REPLACE, relations) @@ -1465,7 +1420,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(self.parent_id) # The relations are not returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 0) self.assertEqual(relations, {}) @@ -1475,11 +1431,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # The relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) @@ -1491,7 +1447,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(self.parent_id) # The relations are returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( relations["m.annotation"], @@ -1512,14 +1469,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # Redact one of the reactions. self._redact(self.parent_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(len(event_ids), 1) self.assertDictContainsSubset( { -- cgit 1.5.1 From 96274565ff0dbb7d21b02b04fcef115330426707 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 12:17:39 -0400 Subject: Fix bundling aggregations if unsigned is not a returned event field. (#12234) An error occured if a filter was supplied with `event_fields` which did not include `unsigned`. In that case, bundled aggregations are still added as the spec states it is allowed for servers to add additional fields. --- changelog.d/12234.bugfix | 1 + synapse/events/utils.py | 9 ++++++--- tests/rest/client/test_relations.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12234.bugfix (limited to 'tests/rest') diff --git a/changelog.d/12234.bugfix b/changelog.d/12234.bugfix new file mode 100644 index 0000000000..dbb77f36ff --- /dev/null +++ b/changelog.d/12234.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when a `filter` argument with `event_fields` supplied but not including the `unsigned` field could result in a 500 error on `/sync`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b2a237c1e0..a0520068e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -530,9 +530,12 @@ class EventClientSerializer: # Include the bundled aggregations in the event. if serialized_aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - serialized_aggregations - ) + # There is likely already an "unsigned" field, but a filter might + # have stripped it off (via the event_fields option). The server is + # allowed to return additional fields, so add it back. + serialized_event.setdefault("unsigned", {}).setdefault( + "m.relations", {} + ).update(serialized_aggregations) def serialize_events( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 0cbe6c0cf7..171f4e97c8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1267,6 +1267,34 @@ class RelationsTestCase(BaseRelationsTestCase): [annotation_event_id_good, thread_event_id], ) + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + class RelationRedactionTestCase(BaseRelationsTestCase): """ -- cgit 1.5.1 From 80e0e1f35e6b1cdfa0267f9c40a6f212b7d774de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:15:45 -0400 Subject: Only fetch thread participation for events with threads. (#12228) We fetch the thread summary in two phases: 1. The summary that is shared by all users (count of messages and latest event). 2. Whether the requesting user has participated in the thread. There's no use in attempting step 2 for events which did not return a summary from step 1. --- changelog.d/12228.bugfix | 1 + synapse/storage/databases/main/relations.py | 4 +- tests/rest/client/test_relations.py | 509 +++++++++++++++------------- tests/server.py | 20 +- 4 files changed, 289 insertions(+), 245 deletions(-) create mode 100644 changelog.d/12228.bugfix (limited to 'tests/rest') diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix new file mode 100644 index 0000000000..4755777139 --- /dev/null +++ b/changelog.d/12228.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c4869d64e6..af2334a65e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -857,7 +857,9 @@ class RelationsWorkerStore(SQLBaseStore): summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. - participated = await self._get_threads_participated(summaries.keys(), user_id) + participated = await self._get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) for event_id, summary in summaries.items(): if summary: thread_count, latest_thread_event, edit = summary diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f3741b3001..329690f8f7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -155,6 +155,16 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.json_body) return channel.json_body["chunk"] + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: @@ -291,202 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_2 = channel.json_body["event_id"] - - self._send_relation(RelationTypes.THREAD, "m.room.test") - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) - - # Request search. - channel = self.make_request( - "POST", - "/search", - # Search term matches the parent message. - content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - chunk = [ - result["result"] - for result in channel.json_body["search_categories"]["room_events"][ - "results" - ] - ] - assert_bundle(self._find_event_in_chunk(chunk)) - - def test_aggregation_get_event_for_annotation(self) -> None: - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - - def test_aggregation_get_event_for_thread(self) -> None: - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEqual( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -796,7 +610,7 @@ class RelationsTestCase(BaseRelationsTestCase): threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, @@ -836,7 +650,7 @@ class RelationsTestCase(BaseRelationsTestCase): edit_event_id = channel.json_body["event_id"] # Edit the edit event. - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -912,16 +726,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event - - raise AssertionError(f"Event {self.parent_id} not found in chunk") - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -981,34 +785,6 @@ class RelationsTestCase(BaseRelationsTestCase): [annotation_event_id_good, thread_event_id], ) - def test_bundled_aggregations_with_filter(self) -> None: - """ - If "unsigned" is an omitted field (due to filtering), adding the bundled - aggregations should not break. - - Note that the spec allows for a server to return additional fields beyond - what is specified. - """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - - # Note that the sync filter does not include "unsigned" as a field. - filter = urllib.parse.quote_plus( - b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' - ) - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Ensure the timeline is limited, find the parent event. - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - parent_event = self._find_event_in_chunk(room_timeline["events"]) - - # Ensure there's bundled aggregations on it. - self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) - class RelationPaginationTestCase(BaseRelationsTestCase): def test_basic_paginate_relations(self) -> None: @@ -1255,7 +1031,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): idx += 1 # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") prev_token = "" found_event_ids: List[str] = [] @@ -1291,6 +1067,263 @@ class RelationPaginationTestCase(BaseRelationsTestCase): self.assertEqual(found_event_ids, expected_event_ids) +class BundledAggregationsTestCase(BaseRelationsTestCase): + """ + See RelationsTestCase.test_edit for a similar test for edits. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + """ + + def _test_bundled_aggregations( + self, + relation_type: str, + assertion_callable: Callable[[JsonDict], None], + expected_db_txn_for_event: int, + ) -> None: + """ + Makes requests to various endpoints which should include bundled aggregations + and then calls an assertion function on the bundled aggregations. + + Args: + relation_type: The field to search for in the `m.relations` field in unsigned. + assertion_callable: Called with the contents of unsigned["m.relations"][relation_type] + for relation-specific assertions. + expected_db_txn_for_event: The number of database transactions which + are expected for a call to /event/. + """ + + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + + # Ensure the fields are as expected. + self.assertCountEqual(relations_dict.keys(), (relation_type,)) + assertion_callable(relations_dict[relation_type]) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) + + # Request sync. + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_annotation(self) -> None: + """ + Test that annotations get correctly bundled. + """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_reference(self) -> None: + """ + Test that references get correctly bundled. + """ + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_thread(self) -> None: + """ + Test that threads get correctly bundled. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + + def test_aggregation_get_event_for_annotation(self) -> None: + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self) -> None: + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEqual( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + + class RelationRedactionTestCase(BaseRelationsTestCase): """ Test the behaviour of relations when the parent or child event is redacted. diff --git a/tests/server.py b/tests/server.py index 82990c2eb9..6ce2a17bf4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -54,13 +54,18 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import ( + AccumulatingProtocol, + MemoryReactor, + MemoryReactorClock, +) from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -88,18 +93,19 @@ class TimedOutException(Exception): """ -@attr.s +@attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). """ - site = attr.ib(type=Union[Site, "FakeSite"]) - _reactor = attr.ib() - result = attr.ib(type=dict, default=attr.Factory(dict)) - _ip = attr.ib(type=str, default="127.0.0.1") + site: Union[Site, "FakeSite"] + _reactor: MemoryReactor + result: dict = attr.Factory(dict) + _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None + resource_usage: Optional[ContextResourceUsage] = None @property def json_body(self): @@ -168,6 +174,8 @@ class FakeChannel: def requestDone(self, _self): self.result["done"] = True + if isinstance(_self, SynapseRequest): + self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): # We give an address so that getClientIP returns a non null entry, -- cgit 1.5.1 From 516d092ff95d02c0bb2133c9316a1fb4ff2f5072 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 23 Mar 2022 12:19:20 +0100 Subject: Rename shared_rooms to mutual_rooms (#12036) Co-authored-by: reivilibre --- changelog.d/12036.misc | 1 + synapse/rest/__init__.py | 4 +- synapse/rest/client/mutual_rooms.py | 76 ++++++++++++ synapse/rest/client/shared_rooms.py | 75 ------------ synapse/storage/databases/main/user_directory.py | 6 +- tests/rest/client/test_mutual_rooms.py | 146 +++++++++++++++++++++++ tests/rest/client/test_shared_rooms.py | 146 ----------------------- 7 files changed, 228 insertions(+), 226 deletions(-) create mode 100644 changelog.d/12036.misc create mode 100644 synapse/rest/client/mutual_rooms.py delete mode 100644 synapse/rest/client/shared_rooms.py create mode 100644 tests/rest/client/test_mutual_rooms.py delete mode 100644 tests/rest/client/test_shared_rooms.py (limited to 'tests/rest') diff --git a/changelog.d/12036.misc b/changelog.d/12036.misc new file mode 100644 index 0000000000..d2996730cc --- /dev/null +++ b/changelog.d/12036.misc @@ -0,0 +1 @@ +Rename `shared_rooms` to `mutual_rooms` (MSC2666), as per proposal changes. \ No newline at end of file diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 762808a571..57c4773edc 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -32,6 +32,7 @@ from synapse.rest.client import ( knock, login as v1_login, logout, + mutual_rooms, notifications, openid, password_policy, @@ -49,7 +50,6 @@ from synapse.rest.client import ( room_keys, room_upgrade_rest_servlet, sendtodevice, - shared_rooms, sync, tags, thirdparty, @@ -132,4 +132,4 @@ class ClientRestResource(JsonResource): admin.register_servlets_for_client_rest_resource(hs, client_resource) # unstable - shared_rooms.register_servlets(hs, client_resource) + mutual_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py new file mode 100644 index 0000000000..d3872a76c8 --- /dev/null +++ b/synapse/rest/client/mutual_rooms.py @@ -0,0 +1,76 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UserMutualRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/mutual_rooms/(?P[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.user_directory_active = hs.config.server.update_user_directory + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + + rooms = await self.store.get_mutual_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + UserMutualRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py deleted file mode 100644 index e669fa7890..0000000000 --- a/synapse/rest/client/shared_rooms.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING, Tuple - -from synapse.api.errors import Codes, SynapseError -from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class UserSharedRoomsServlet(RestServlet): - """ - GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", - releases=(), # This is an unstable feature - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.user_directory_active = hs.config.server.update_user_directory - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - - if not self.user_directory_active: - raise SynapseError( - code=400, - msg="The user directory is disabled on this server. Cannot determine shared rooms.", - errcode=Codes.FORBIDDEN, - ) - - UserID.from_string(user_id) - - requester = await self.auth.get_user_by_req(request) - if user_id == requester.user.to_string(): - raise SynapseError( - code=400, - msg="You cannot request a list of shared rooms with yourself", - errcode=Codes.FORBIDDEN, - ) - rooms = await self.store.get_shared_rooms_for_users( - requester.user.to_string(), user_id - ) - - return 200, {"joined": list(rooms)} - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 55cc9178f0..0595df01d3 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -730,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_shared_rooms_for_users( + async def get_mutual_rooms_for_users( self, user_id: str, other_user_id: str ) -> Set[str]: """ @@ -744,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): A set of room ID's that the users share. """ - def _get_shared_rooms_for_users_txn( + def _get_mutual_rooms_for_users_txn( txn: LoggingTransaction, ) -> List[Dict[str, str]]: txn.execute( @@ -768,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return rows rows = await self.db_pool.runInteraction( - "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn + "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn ) return {row["room_id"] for row in rows} diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py new file mode 100644 index 0000000000..7b7d283bb6 --- /dev/null +++ b/tests/rest/client/test_mutual_rooms.py @@ -0,0 +1,146 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, mutual_rooms, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.server import FakeChannel + + +class UserMutualRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserMutualRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + mutual_rooms.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.handler = hs.get_user_directory_handler() + + def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: + return self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s" + % other_user, + access_token=token, + ) + + def test_shared_room_list_public(self) -> None: + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True) + + def test_shared_room_list_private(self) -> None: + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + self._check_mutual_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) + + def test_shared_room_list_mixed(self) -> None: + """ + The shared room list between two users should contain both public and private + rooms. + """ + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_mutual_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ) -> None: + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) + + def test_shared_room_list_after_leave(self) -> None: + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + # Check user1's view of shared rooms with user2 + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 + channel = self._get_mutual_rooms(u2_token, u1) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py deleted file mode 100644 index 3818b7b14b..0000000000 --- a/tests/rest/client/test_shared_rooms.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms -from synapse.server import HomeServer -from synapse.util import Clock - -from tests import unittest -from tests.server import FakeChannel - - -class UserSharedRoomsTest(unittest.HomeserverTestCase): - """ - Tests the UserSharedRoomsServlet. - """ - - servlets = [ - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - shared_rooms.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.handler = hs.get_user_directory_handler() - - def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: - return self.make_request( - "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" - % other_user, - access_token=token, - ) - - def test_shared_room_list_public(self) -> None: - """ - A room should show up in the shared list of rooms between two users - if it is public. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - - def test_shared_room_list_private(self) -> None: - """ - A room should show up in the shared list of rooms between two users - if it is private. - """ - self._check_shared_rooms_with( - room_one_is_public=False, room_two_is_public=False - ) - - def test_shared_room_list_mixed(self) -> None: - """ - The shared room list between two users should contain both public and private - rooms. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) - - def _check_shared_rooms_with( - self, room_one_is_public: bool, room_two_is_public: bool - ) -> None: - """Checks that shared public or private rooms between two users appear in - their shared room lists - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - # Create a room. user1 invites user2, who joins - room_id_one = self.helper.create_room_as( - u1, is_public=room_one_is_public, tok=u1_token - ) - self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_one, user=u2, tok=u2_token) - - # Check shared rooms from user1's perspective. - # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 1) - self.assertEqual(channel.json_body["joined"][0], room_id_one) - - # Create another room and invite user2 to it - room_id_two = self.helper.create_room_as( - u1, is_public=room_two_is_public, tok=u1_token - ) - self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_two, user=u2, tok=u2_token) - - # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 2) - for room_id_id in channel.json_body["joined"]: - self.assertIn(room_id_id, [room_id_one, room_id_two]) - - def test_shared_room_list_after_leave(self) -> None: - """ - A room should no longer be considered shared if the other - user has left it. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 1) - self.assertEqual(channel.json_body["joined"][0], room) - - self.helper.leave(room, user=u1, tok=u1_token) - - # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 0) - - # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 0) -- cgit 1.5.1 From 5436b014f44699093dd75d0ecbf26c434feecaa0 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 24 Mar 2022 11:19:41 +0100 Subject: Optionally include account validity in MSC3720 account status responses (#12266) --- changelog.d/12266.misc | 1 + synapse/config/server.py | 4 +++ synapse/handlers/account.py | 11 ++++++++ tests/rest/client/test_account.py | 58 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12266.misc (limited to 'tests/rest') diff --git a/changelog.d/12266.misc b/changelog.d/12266.misc new file mode 100644 index 0000000000..59e2718370 --- /dev/null +++ b/changelog.d/12266.misc @@ -0,0 +1 @@ +Optionally include account validity expiration information to experimental [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) account status responses. diff --git a/synapse/config/server.py b/synapse/config/server.py index 49cd0a4f19..38de4b8000 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -676,6 +676,10 @@ class ServerConfig(Config): ): raise ConfigError("'custom_template_directory' must be a string") + self.use_account_validity_in_account_status: bool = ( + config.get("use_account_validity_in_account_status") or False + ) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index d5badf635b..c05a14304c 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -26,6 +26,10 @@ class AccountHandler: self._main_store = hs.get_datastores().main self._is_mine = hs.is_mine self._federation_client = hs.get_federation_client() + self._use_account_validity_in_account_status = ( + hs.config.server.use_account_validity_in_account_status + ) + self._account_validity_handler = hs.get_account_validity_handler() async def get_account_statuses( self, @@ -106,6 +110,13 @@ class AccountHandler: "deactivated": userinfo.is_deactivated, } + if self._use_account_validity_in_account_status: + status[ + "org.matrix.expired" + ] = await self._account_validity_handler.is_user_expired( + user_id.to_string() + ) + return status async def _get_remote_account_statuses( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index def836054d..27946febff 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -31,7 +31,7 @@ from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest @@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[users[2]], ) + @unittest.override_config( + { + "use_account_validity_in_account_status": True, + } + ) + def test_no_account_validity(self) -> None: + """Tests that if we decide to include account validity in the response but no + account validity 'is_user_expired' callback is provided, we default to marking all + users as not expired. + """ + user = self.register_user("someuser", "password") + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + "org.matrix.expired": False, + }, + }, + expected_failures=[], + ) + + @unittest.override_config( + { + "use_account_validity_in_account_status": True, + } + ) + def test_account_validity_expired(self) -> None: + """Test that if we decide to include account validity in the response and the user + is expired, we return the correct info. + """ + user = self.register_user("someuser", "password") + + async def is_expired(user_id: str) -> bool: + # We can't blindly say everyone is expired, otherwise the request to get the + # account status will fail. + return UserID.from_string(user_id).localpart == "someuser" + + self.hs.get_account_validity_handler()._is_user_expired_callbacks.append( + is_expired + ) + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + "org.matrix.expired": True, + }, + }, + expected_failures=[], + ) + def _test_status( self, users: Optional[List[str]], -- cgit 1.5.1 From 4df10d32148ae29f792afc68ff774bcbd1915cea Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 24 Mar 2022 10:25:42 -0400 Subject: Do not consider events by ignored users for relations (#12285) Filter the events returned from `/relations` for the requester's ignored users in a similar way to `/messages` (and `/sync`). --- changelog.d/12227.bugfix | 1 + changelog.d/12227.misc | 1 - changelog.d/12232.bugfix | 1 + changelog.d/12232.misc | 1 - changelog.d/12285.bugfix | 1 + synapse/handlers/relations.py | 9 ++++- tests/rest/client/test_relations.py | 80 ++++++++++++++++++++++++++++++++++++- 7 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12227.bugfix delete mode 100644 changelog.d/12227.misc create mode 100644 changelog.d/12232.bugfix delete mode 100644 changelog.d/12232.misc create mode 100644 changelog.d/12285.bugfix (limited to 'tests/rest') diff --git a/changelog.d/12227.bugfix b/changelog.d/12227.bugfix new file mode 100644 index 0000000000..1a7dccf465 --- /dev/null +++ b/changelog.d/12227.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where events from ignored users were still considered for relations. diff --git a/changelog.d/12227.misc b/changelog.d/12227.misc deleted file mode 100644 index 41c9dcbd37..0000000000 --- a/changelog.d/12227.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/changelog.d/12232.bugfix b/changelog.d/12232.bugfix new file mode 100644 index 0000000000..1a7dccf465 --- /dev/null +++ b/changelog.d/12232.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where events from ignored users were still considered for relations. diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc deleted file mode 100644 index 4a4132edff..0000000000 --- a/changelog.d/12232.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor relations tests to improve code re-use. diff --git a/changelog.d/12285.bugfix b/changelog.d/12285.bugfix new file mode 100644 index 0000000000..1a7dccf465 --- /dev/null +++ b/changelog.d/12285.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where events from ignored users were still considered for relations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 57135d4519..73217d135d 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -21,6 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.types import JsonDict, Requester, StreamToken +from synapse.visibility import filter_events_for_client if TYPE_CHECKING: from synapse.server import HomeServer @@ -62,6 +63,7 @@ class BundledAggregations: class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main + self._storage = hs.get_storage() self._auth = hs.get_auth() self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() @@ -103,7 +105,8 @@ class RelationsHandler: user_id = requester.user.to_string() - await self._auth.check_user_in_room_or_world_readable( + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) @@ -130,6 +133,10 @@ class RelationsHandler: [c["event_id"] for c in pagination_chunk.chunk] ) + events = await filter_events_for_client( + self._storage, user_id, events, is_peeking=(member_event_id is None) + ) + now = self._clock.time_msec() # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 329690f8f7..fe97a0b3dd 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -20,7 +20,7 @@ from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync from synapse.server import HomeServer @@ -1324,6 +1324,84 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertIn("m.relations", parent_event["unsigned"]) +class RelationIgnoredUserTestCase(BaseRelationsTestCase): + """Relations sent from an ignored user should be ignored.""" + + def _test_ignored_user( + self, allowed_event_ids: List[str], ignored_event_ids: List[str] + ) -> None: + """ + Fetch the relations and ensure they're all there, then ignore user2, and + repeat. + """ + # Get the relations. + event_ids = self._get_related_events() + self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids) + + # Ignore user2 and re-do the requests. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Get the relations. + event_ids = self._get_related_events() + self.assertCountEqual(event_ids, allowed_event_ids) + + def test_annotation(self) -> None: + """Annotations should ignore""" + # Send 2 from us, 2 from the to be ignored user. + allowed_event_ids = [] + ignored_event_ids = [] + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + allowed_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b") + allowed_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="a", + access_token=self.user2_token, + ) + ignored_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="c", + access_token=self.user2_token, + ) + ignored_event_ids.append(channel.json_body["event_id"]) + + self._test_ignored_user(allowed_event_ids, ignored_event_ids) + + def test_reference(self) -> None: + """Annotations should ignore""" + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + allowed_event_ids = [channel.json_body["event_id"]] + + channel = self._send_relation( + RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token + ) + ignored_event_ids = [channel.json_body["event_id"]] + + self._test_ignored_user(allowed_event_ids, ignored_event_ids) + + def test_thread(self) -> None: + """Annotations should ignore""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + allowed_event_ids = [channel.json_body["event_id"]] + + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + ignored_event_ids = [channel.json_body["event_id"]] + + self._test_ignored_user(allowed_event_ids, ignored_event_ids) + + class RelationRedactionTestCase(BaseRelationsTestCase): """ Test the behaviour of relations when the parent or child event is redacted. -- cgit 1.5.1 From fffb3c4c8f67c271a723855835c2ea0fb83fc33f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 25 Mar 2022 13:28:42 +0000 Subject: Always allow the empty string as an avatar_url. (#12261) Hopefully this fixes #12257. Co-authored-by: Patrick Cloke --- changelog.d/12261.bugfix | 1 + synapse/handlers/profile.py | 6 ++++++ tests/handlers/test_profile.py | 6 ++++++ tests/rest/admin/test_user.py | 19 +++++++++++++++++++ 4 files changed, 32 insertions(+) create mode 100644 changelog.d/12261.bugfix (limited to 'tests/rest') diff --git a/changelog.d/12261.bugfix b/changelog.d/12261.bugfix new file mode 100644 index 0000000000..1bfde4c380 --- /dev/null +++ b/changelog.d/12261.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.52 where admins could not deactivate and GDPR-erase a user if Synapse was configured with limits on avatars. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6554c0d3c2..239b0aa744 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -336,12 +336,18 @@ class ProfileHandler: """Check that the size and content type of the avatar at the given MXC URI are within the configured limits. + If the given `mxc` is empty, no checks are performed. (Users are always able to + unset their avatar.) + Args: mxc: The MXC URI at which the avatar can be found. Returns: A boolean indicating whether the file can be allowed to be set as an avatar. """ + if mxc == "": + return True + if not self.max_avatar_size and not self.allowed_avatar_mimetypes: return True diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 08733a9f2d..1ec105c373 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -267,6 +267,12 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertTrue(res) + @unittest.override_config({"max_avatar_size": 50}) + def test_avatar_constraints_allow_empty_avatar_url(self) -> None: + """An empty avatar is always permitted.""" + res = self.get_success(self.handler.check_avatar_size_and_mime_type("")) + self.assertTrue(res) + @unittest.override_config({"max_avatar_size": 50}) def test_avatar_constraints_missing(self) -> None: """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a60ea0a563..bef911d5df 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", True) + @override_config({"max_avatar_size": 1234}) + def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None: + """Check we can erase a user whose avatar is the empty string. + + Reproduces #12257. + """ + # Patch `self.other_user` to have an empty string as their avatar. + self.get_success(self.store.set_profile_avatar_url("user", "")) + + # Check we can still erase them. + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"erase": True}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self._is_erased("@user:test", True) + def test_deactivate_user_erase_false(self) -> None: """ Test deactivating a user and set `erase` to `false` -- cgit 1.5.1