From 9f2016e96e800460c390c2f2de85797910954ca6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 21 Jan 2022 09:19:56 +0000 Subject: Drop unused table `public_room_list_stream`. (#11795) This is a follow-up to #10565. --- tests/rest/admin/test_room.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests/rest') diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 3495a0366a..23da0ad736 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2468,7 +2468,6 @@ PURGE_TABLES = [ "event_search", "events", "group_rooms", - "public_room_list_stream", "receipts_graph", "receipts_linearized", "room_aliases", -- cgit 1.5.1 From b784299cbc121d27d7dadd0a4a96f4657244a4e9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Jan 2022 05:31:31 -0500 Subject: Do not try to serialize raw aggregations dict. (#11791) --- changelog.d/11612.bugfix | 1 + changelog.d/11612.misc | 1 - changelog.d/11791.bugfix | 1 + synapse/events/utils.py | 4 +- synapse/rest/admin/rooms.py | 13 ++--- synapse/rest/client/room.py | 11 ++-- tests/rest/client/test_relations.py | 108 ++++++++++++++++++++++++------------ 7 files changed, 85 insertions(+), 54 deletions(-) create mode 100644 changelog.d/11612.bugfix delete mode 100644 changelog.d/11612.misc create mode 100644 changelog.d/11791.bugfix (limited to 'tests/rest') diff --git a/changelog.d/11612.bugfix b/changelog.d/11612.bugfix new file mode 100644 index 0000000000..842f6892fd --- /dev/null +++ b/changelog.d/11612.bugfix @@ -0,0 +1 @@ +Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/changelog.d/11612.misc b/changelog.d/11612.misc deleted file mode 100644 index 2d886169c5..0000000000 --- a/changelog.d/11612.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid database access in the JSON serialization process. diff --git a/changelog.d/11791.bugfix b/changelog.d/11791.bugfix new file mode 100644 index 0000000000..842f6892fd --- /dev/null +++ b/changelog.d/11791.bugfix @@ -0,0 +1 @@ +Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index de0e0c1731..918adeecf8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -402,7 +402,7 @@ class EventClientSerializer: if bundle_aggregations: event_aggregations = bundle_aggregations.get(event.event_id) if event_aggregations: - self._injected_bundled_aggregations( + self._inject_bundled_aggregations( event, time_now, bundle_aggregations[event.event_id], @@ -411,7 +411,7 @@ class EventClientSerializer: return serialized_event - def _injected_bundled_aggregations( + def _inject_bundled_aggregations( self, event: EventBase, time_now: int, diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 2e714ac87b..efe25fe7eb 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -744,20 +744,15 @@ class RoomEventContextServlet(RestServlet): ) time_now = self.clock.time_msec() + aggregations = results.pop("aggregations", None) results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_before"], time_now, bundle_aggregations=aggregations ) results["event"] = self._event_serializer.serialize_event( - results["event"], - time_now, - bundle_aggregations=results["aggregations"], + results["event"], time_now, bundle_aggregations=aggregations ) results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_after"], time_now, bundle_aggregations=aggregations ) results["state"] = self._event_serializer.serialize_events( results["state"], time_now diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 31fd329a38..90bb9142a0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -714,18 +714,15 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + aggregations = results.pop("aggregations", None) results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_before"], time_now, bundle_aggregations=aggregations ) results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=results["aggregations"] + results["event"], time_now, bundle_aggregations=aggregations ) results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_after"], time_now, bundle_aggregations=aggregations ) results["state"] = self._event_serializer.serialize_events( results["state"], time_now diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 4b20ab0e3e..c9b220e73d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -21,6 +21,7 @@ from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.types import JsonDict from tests import unittest from tests.server import FakeChannel @@ -454,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_bundled_aggregations(self): - """Test that annotations, references, and threads get correctly bundled.""" + """ + Test that annotations, references, and threads get correctly bundled. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + + See test_edit for a similar test for edits. + """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -482,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - def assert_bundle(actual): + def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") # Ensure the fields are as expected. self.assertCountEqual( - actual.keys(), + relations_dict.keys(), ( RelationTypes.ANNOTATION, RelationTypes.REFERENCE, @@ -503,20 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"type": "m.reaction", "key": "b", "count": 1}, ] }, - actual[RelationTypes.ANNOTATION], + relations_dict[RelationTypes.ANNOTATION], ) self.assertEquals( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - actual[RelationTypes.REFERENCE], + relations_dict[RelationTypes.REFERENCE], ) self.assertEquals( 2, - actual[RelationTypes.THREAD].get("count"), + relations_dict[RelationTypes.THREAD].get("count"), ) self.assertTrue( - actual[RelationTypes.THREAD].get("current_user_participated") + relations_dict[RelationTypes.THREAD].get("current_user_participated") ) # The latest thread event has some fields that don't matter. self.assert_dict( @@ -533,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "type": "m.room.test", "user_id": self.user_id, }, - actual[RelationTypes.THREAD].get("latest_event"), + relations_dict[RelationTypes.THREAD].get("latest_event"), ) - def _find_and_assert_event(events): - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - break - else: - raise AssertionError(f"Event {self.parent_id} not found in chunk") - assert_bundle(event["unsigned"].get("m.relations")) - # Request the event directly. channel = self.make_request( "GET", @@ -554,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["unsigned"].get("m.relations")) + assert_bundle(channel.json_body) # Request the room messages. channel = self.make_request( @@ -563,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - _find_and_assert_event(channel.json_body["chunk"]) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. channel = self.make_request( @@ -572,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) - - # Note that /relations is tested separately in test_aggregation_get_event_for_thread - # since it needs different data configured. + self._find_event_in_chunk(room_timeline["events"]) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included @@ -777,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + assert_bundle(channel.json_body) - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + # Request sync, but limit the timeline so it becomes limited (and includes + # bundled aggregations). + filter = urllib.parse.quote_plus( + '{"room": {"timeline": {"limit": 2}}}'.encode() + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token ) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_multi_edit(self): """Test that multiple edits, including attempts by people who @@ -1102,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["chunk"], []) + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + def _send_relation( self, relation_type: str, -- cgit 1.5.1 From 807efd26aec9b65c6a2f02d10fd139095a5b3387 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Jan 2022 08:58:18 -0500 Subject: Support rendering previews with data: URLs in them (#11767) Images which are data URLs will no longer break URL previews and will properly be "downloaded" and thumbnailed. --- changelog.d/11767.bugfix | 1 + synapse/rest/media/v1/preview_html.py | 31 +- synapse/rest/media/v1/preview_url_resource.py | 224 ++++++++---- tests/rest/media/v1/test_html_preview.py | 481 ++++++++++++++++++++++++++ tests/rest/media/v1/test_url_preview.py | 81 ++++- tests/server.py | 2 +- tests/test_preview.py | 449 ------------------------ 7 files changed, 747 insertions(+), 522 deletions(-) create mode 100644 changelog.d/11767.bugfix create mode 100644 tests/rest/media/v1/test_html_preview.py delete mode 100644 tests/test_preview.py (limited to 'tests/rest') diff --git a/changelog.d/11767.bugfix b/changelog.d/11767.bugfix new file mode 100644 index 0000000000..3e344747f4 --- /dev/null +++ b/changelog.d/11767.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when previewing Reddit URLs which do not contain an image. diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 30b067dd42..872a9e72e8 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -321,14 +321,33 @@ def _iterate_over_text( def rebase_url(url: str, base: str) -> str: - base_parts = list(urlparse.urlparse(base)) + """ + Resolves a potentially relative `url` against an absolute `base` URL. + + For example: + + >>> rebase_url("subpage", "https://example.com/foo/") + 'https://example.com/foo/subpage' + >>> rebase_url("sibling", "https://example.com/foo") + 'https://example.com/sibling' + >>> rebase_url("/bar", "https://example.com/foo/") + 'https://example.com/bar' + >>> rebase_url("https://alice.com/a/", "https://example.com/foo/") + 'https://alice.com/a' + """ + base_parts = urlparse.urlparse(base) + # Convert the parsed URL to a list for (potential) modification. url_parts = list(urlparse.urlparse(url)) - if not url_parts[0]: # fix up schema - url_parts[0] = base_parts[0] or "http" - if not url_parts[1]: # fix up hostname - url_parts[1] = base_parts[1] + # Add a scheme, if one does not exist. + if not url_parts[0]: + url_parts[0] = base_parts.scheme or "http" + # Fix up the hostname, if this is not a data URL. + if url_parts[0] != "data" and not url_parts[1]: + url_parts[1] = base_parts.netloc + # If the path does not start with a /, nest it under the base path's last + # directory. if not url_parts[2].startswith("/"): - url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2] return urlparse.urlunparse(url_parts) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e8881bc870..efd84ced8f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -21,8 +21,9 @@ import re import shutil import sys import traceback -from typing import TYPE_CHECKING, Iterable, Optional, Tuple +from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple from urllib import parse as urlparse +from urllib.request import urlopen import attr @@ -70,6 +71,17 @@ ONE_DAY = 24 * ONE_HOUR IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DownloadResult: + length: int + uri: str + response_code: int + media_type: str + download_name: Optional[str] + expires: int + etag: Optional[str] + + @attr.s(slots=True, frozen=True, auto_attribs=True) class MediaInfo: """ @@ -256,7 +268,7 @@ class PreviewUrlResource(DirectServeJsonResource): if oembed_url: url_to_download = oembed_url - media_info = await self._download_url(url_to_download, user) + media_info = await self._handle_url(url_to_download, user) logger.debug("got media_info of '%s'", media_info) @@ -297,7 +309,9 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._download_url(oembed_url, user) + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) ( og_from_oembed, author_name, @@ -367,7 +381,135 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - async def _download_url(self, url: str, user: UserID) -> MediaInfo: + async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult: + """ + Fetches a remote URL and parses the headers. + + Args: + url: The URL to fetch. + output_stream: The stream to write the content to. + + Returns: + A tuple of: + Media length, URL downloaded, the HTTP response code, + the media type, the downloaded file name, the number of + milliseconds the result is valid for, the etag header. + """ + + try: + logger.debug("Trying to get preview for url '%s'", url) + length, headers, uri, code = await self.client.get_file( + url, + output_stream=output_stream, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url, e) + + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + + return DownloadResult( + length, uri, code, media_type, download_name, expires, etag + ) + + async def _parse_data_url( + self, url: str, output_stream: BinaryIO + ) -> DownloadResult: + """ + Parses a data: URL. + + Args: + url: The URL to parse. + output_stream: The stream to write the content to. + + Returns: + A tuple of: + Media length, URL downloaded, the HTTP response code, + the media type, the downloaded file name, the number of + milliseconds the result is valid for, the etag header. + """ + + try: + logger.debug("Trying to parse data url '%s'", url) + with urlopen(url) as url_info: + # TODO Can this be more efficient. + output_stream.write(url_info.read()) + except Exception as e: + logger.warning("Error parsing data: URL %s: %r", url, e) + + raise SynapseError( + 500, + "Failed to parse data URL: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + + return DownloadResult( + # Read back the length that has been written. + length=output_stream.tell(), + uri=url, + # If it was parsed, consider this a 200 OK. + response_code=200, + # urlopen shoves the media-type from the data URL into the content type + # header object. + media_type=url_info.headers.get_content_type(), + # Some features are not supported by data: URLs. + download_name=None, + expires=ONE_HOUR, + etag=None, + ) + + async def _handle_url( + self, url: str, user: UserID, allow_data_urls: bool = False + ) -> MediaInfo: + """ + Fetches content from a URL and parses the result to generate a MediaInfo. + + It uses the media storage provider to persist the fetched content and + stores the mapping into the database. + + Args: + url: The URL to fetch. + user: The user who ahs requested this URL. + allow_data_urls: True if data URLs should be allowed. + + Returns: + A MediaInfo object describing the fetched content. + """ + # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -377,61 +519,27 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) with self.media_storage.store_into_file(file_info) as (f, fname, finish): - try: - logger.debug("Trying to get preview for url '%s'", url) - length, headers, uri, code = await self.client.get_file( - url, - output_stream=f, - max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url, e) - - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - await finish() + if url.startswith("data:"): + if not allow_data_urls: + raise SynapseError( + 500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN + ) - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") + download_result = await self._parse_data_url(url, f) else: - media_type = "application/octet-stream" + download_result = await self._download_url(url, f) - download_name = get_filename_from_headers(headers) - - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - expires = ONE_HOUR - etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + await finish() try: time_now_ms = self.clock.time_msec() await self.store.store_local_media( media_id=file_id, - media_type=media_type, + media_type=download_result.media_type, time_now_ms=time_now_ms, - upload_name=download_name, - media_length=length, + upload_name=download_result.download_name, + media_length=download_result.length, user_id=user, url_cache=url, ) @@ -444,16 +552,16 @@ class PreviewUrlResource(DirectServeJsonResource): raise return MediaInfo( - media_type=media_type, - media_length=length, - download_name=download_name, + media_type=download_result.media_type, + media_length=download_result.length, + download_name=download_result.download_name, created_ts_ms=time_now_ms, filesystem_id=file_id, filename=fname, - uri=uri, - response_code=code, - expires=expires, - etag=etag, + uri=download_result.uri, + response_code=download_result.response_code, + expires=download_result.expires, + etag=download_result.etag, ) async def _precache_image_url( @@ -474,8 +582,8 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - image_info = await self._download_url( - rebase_url(og["og:image"], media_info.uri), user + image_info = await self._handle_url( + rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True ) if _is_media(image_info.media_type): diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py new file mode 100644 index 0000000000..a4b57e3d1f --- /dev/null +++ b/tests/rest/media/v1/test_html_preview.py @@ -0,0 +1,481 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.media.v1.preview_html import ( + _get_html_media_encodings, + decode_body, + parse_html_to_open_graph, + rebase_url, + summarize_paragraphs, +) + +from tests import unittest + +try: + import lxml +except ImportError: + lxml = None + + +class SummarizeTestCase(unittest.TestCase): + if not lxml: + skip = "url preview feature requires lxml" + + def test_long_summarize(self): + example_paras = [ + """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in + Troms county, Norway. The administrative centre of the municipality is + the city of Tromsø. Outside of Norway, Tromso and Tromsö are + alternative spellings of the city.Tromsø is considered the northernmost + city in the world with a population above 50,000. The most populous town + north of it is Alta, Norway, with a population of 14,272 (2013).""", + """Tromsø lies in Northern Norway. The municipality has a population of + (2015) 72,066, but with an annual influx of students it has over 75,000 + most of the year. It is the largest urban area in Northern Norway and the + third largest north of the Arctic Circle (following Murmansk and Norilsk). + Most of Tromsø, including the city centre, is located on the island of + Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, + Tromsøya had a population of 36,088. Substantial parts of the urban area + are also situated on the mainland to the east, and on parts of Kvaløya—a + large island to the west. Tromsøya is connected to the mainland by the Tromsø + Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the + Sandnessund Bridge. Tromsø Airport connects the city to many destinations + in Europe. The city is warmer than most other places located on the same + latitude, due to the warming effect of the Gulf Stream.""", + """The city centre of Tromsø contains the highest number of old wooden + houses in Northern Norway, the oldest house dating from 1789. The Arctic + Cathedral, a modern church from 1965, is probably the most famous landmark + in Tromsø. The city is a cultural centre for its region, with several + festivals taking place in the summer. Some of Norway's best-known + musicians, Torbjørn Brundtland and Svein Berge of the electronica duo + Röyksopp and Lene Marlin grew up and started their careers in Tromsø. + Noted electronic musician Geir Jenssen also hails from Tromsø.""", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEqual( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway. The administrative centre of the municipality is" + " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + " alternative spellings of the city.Tromsø is considered the northernmost" + " city in the world with a population above 50,000. The most populous town" + " north of it is Alta, Norway, with a population of 14,272 (2013).", + ) + + desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) + + self.assertEqual( + desc, + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. It is the largest urban area in Northern Norway and the" + " third largest north of the Arctic Circle (following Murmansk and Norilsk)." + " Most of Tromsø, including the city centre, is located on the island of" + " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + " Tromsøya had a population of 36,088. Substantial parts of the urban…", + ) + + def test_short_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + "The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEqual( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + ) + + def test_small_then_large_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + " The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + self.assertEqual( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. The city centre of Tromsø contains the highest number" + " of old wooden houses in Northern Norway, the oldest house dating from" + " 1789. The Arctic Cathedral, a modern church from…", + ) + + +class CalcOgTestCase(unittest.TestCase): + if not lxml: + skip = "url preview feature requires lxml" + + def test_simple(self): + html = b""" + + Foo + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_comment(self): + html = b""" + + Foo + + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_comment2(self): + html = b""" + + Foo + + Some text. + + Some more text. +

Text

+ More text + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual( + og, + { + "og:title": "Foo", + "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", + }, + ) + + def test_script(self): + html = b""" + + Foo + + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_missing_title(self): + html = b""" + + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + + def test_h1_as_title(self): + html = b""" + + + +

Title

+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) + + def test_missing_title_and_broken_h1(self): + html = b""" + + +

+ Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + + def test_empty(self): + """Test a body with no data in it.""" + html = b"" + tree = decode_body(html, "http://example.com/test.html") + self.assertIsNone(tree) + + def test_no_tree(self): + """A valid body with no tree in it.""" + html = b"\x00" + tree = decode_body(html, "http://example.com/test.html") + self.assertIsNone(tree) + + def test_xml(self): + """Test decoding XML and ensure it works properly.""" + # Note that the strip() call is important to ensure the xml tag starts + # at the initial byte. + html = b""" + + + + + FooSome text. + """.strip() + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_invalid_encoding(self): + """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" + html = b""" + + Foo + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_invalid_encoding2(self): + """A body which doesn't match the sent character encoding.""" + # Note that this contains an invalid UTF-8 sequence in the title. + html = b""" + + \xff\xff Foo + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) + + def test_windows_1252(self): + """A body which uses cp1252, but doesn't declare that.""" + html = b""" + + \xf3 + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) + + +class MediaEncodingTestCase(unittest.TestCase): + def test_meta_charset(self): + """A character encoding is found via the meta tag.""" + encodings = _get_html_media_encodings( + b""" + + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + # A less well-formed version. + encodings = _get_html_media_encodings( + b""" + + < meta charset = ascii> + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + def test_meta_charset_underscores(self): + """A character encoding contains underscore.""" + encodings = _get_html_media_encodings( + b""" + + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) + + def test_xml_encoding(self): + """A character encoding is found via the meta tag.""" + encodings = _get_html_media_encodings( + b""" + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + def test_meta_xml_encoding(self): + """Meta tags take precedence over XML encoding.""" + encodings = _get_html_media_encodings( + b""" + + + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) + + def test_content_type(self): + """A character encoding is found via the Content-Type header.""" + # Test a few variations of the header. + headers = ( + 'text/html; charset="ascii";', + "text/html;charset=ascii;", + 'text/html; charset="ascii"', + "text/html; charset=ascii", + 'text/html; charset="ascii;', + 'text/html; charset=ascii";', + ) + for header in headers: + encodings = _get_html_media_encodings(b"", header) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + def test_fallback(self): + """A character encoding cannot be found in the body or header.""" + encodings = _get_html_media_encodings(b"", "text/html") + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_duplicates(self): + """Ensure each encoding is only attempted once.""" + encodings = _get_html_media_encodings( + b""" + + + + + + """, + 'text/html; charset="UTF_8"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_unknown_invalid(self): + """A character encoding should be ignored if it is unknown or invalid.""" + encodings = _get_html_media_encodings( + b""" + + + + + """, + 'text/html; charset="invalid"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + +class RebaseUrlTestCase(unittest.TestCase): + def test_relative(self): + """Relative URLs should be resolved based on the context of the base URL.""" + self.assertEqual( + rebase_url("subpage", "https://example.com/foo/"), + "https://example.com/foo/subpage", + ) + self.assertEqual( + rebase_url("sibling", "https://example.com/foo"), + "https://example.com/sibling", + ) + self.assertEqual( + rebase_url("/bar", "https://example.com/foo/"), + "https://example.com/bar", + ) + + def test_absolute(self): + """Absolute URLs should not be modified.""" + self.assertEqual( + rebase_url("https://alice.com/a/", "https://example.com/foo/"), + "https://alice.com/a/", + ) + + def test_data(self): + """Data URLs should not be modified.""" + self.assertEqual( + rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), + "data:,Hello%2C%20World%21", + ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 16e904f15b..53f6186213 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -12,9 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import json import os import re +from urllib.parse import urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -23,6 +25,7 @@ from twisted.test.proto_helpers import AccumulatingProtocol from synapse.config.oembed import OEmbedEndpointConfig from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.types import JsonDict from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest @@ -142,6 +145,14 @@ class URLPreviewTests(unittest.HomeserverTestCase): def create_test_resource(self): return self.hs.get_media_repository_resource() + def _assert_small_png(self, json_body: JsonDict) -> None: + """Assert properties from the SMALL_PNG test image.""" + self.assertTrue(json_body["og:image"].startswith("mxc://")) + self.assertEqual(json_body["og:image:height"], 1) + self.assertEqual(json_body["og:image:width"], 1) + self.assertEqual(json_body["og:image:type"], "image/png") + self.assertEqual(json_body["matrix:image:size"], 67) + def test_cache_returns_correct_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] @@ -569,6 +580,66 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_data_url(self): + """ + Requesting to preview a data URL is not supported. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + data = base64.b64encode(SMALL_PNG).decode() + + query_params = urlencode( + { + "url": f'' + } + ) + + channel = self.make_request( + "GET", + f"preview_url?{query_params}", + shorthand=False, + ) + self.pump() + + self.assertEqual(channel.code, 500) + + def test_inline_data_url(self): + """ + An inline image (as a data URL) should be parsed properly. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + data = base64.b64encode(SMALL_PNG) + + end_content = ( + b"" b'' b"" + ) % (data,) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self._assert_small_png(channel.json_body) + def test_oembed_photo(self): """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -626,10 +697,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) body = channel.json_body self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") - self.assertTrue(body["og:image"].startswith("mxc://")) - self.assertEqual(body["og:image:height"], 1) - self.assertEqual(body["og:image:width"], 1) - self.assertEqual(body["og:image:type"], "image/png") + self._assert_small_png(body) def test_oembed_rich(self): """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" @@ -820,10 +888,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345" ) - self.assertTrue(body["og:image"].startswith("mxc://")) - self.assertEqual(body["og:image:height"], 1) - self.assertEqual(body["og:image:width"], 1) - self.assertEqual(body["og:image:type"], "image/png") + self._assert_small_png(body) def _download_image(self): """Downloads an image into the URL cache. diff --git a/tests/server.py b/tests/server.py index a0cd14ea45..82990c2eb9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -313,7 +313,7 @@ def make_request( req = request(channel, site) req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. - req.content.seek(SEEK_END) + req.content.seek(0, SEEK_END) if access_token: req.requestHeaders.addRawHeader( diff --git a/tests/test_preview.py b/tests/test_preview.py deleted file mode 100644 index 46e02f483f..0000000000 --- a/tests/test_preview.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest.media.v1.preview_html import ( - _get_html_media_encodings, - decode_body, - parse_html_to_open_graph, - summarize_paragraphs, -) - -from . import unittest - -try: - import lxml -except ImportError: - lxml = None - - -class SummarizeTestCase(unittest.TestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def test_long_summarize(self): - example_paras = [ - """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: - Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in - Troms county, Norway. The administrative centre of the municipality is - the city of Tromsø. Outside of Norway, Tromso and Tromsö are - alternative spellings of the city.Tromsø is considered the northernmost - city in the world with a population above 50,000. The most populous town - north of it is Alta, Norway, with a population of 14,272 (2013).""", - """Tromsø lies in Northern Norway. The municipality has a population of - (2015) 72,066, but with an annual influx of students it has over 75,000 - most of the year. It is the largest urban area in Northern Norway and the - third largest north of the Arctic Circle (following Murmansk and Norilsk). - Most of Tromsø, including the city centre, is located on the island of - Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, - Tromsøya had a population of 36,088. Substantial parts of the urban area - are also situated on the mainland to the east, and on parts of Kvaløya—a - large island to the west. Tromsøya is connected to the mainland by the Tromsø - Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the - Sandnessund Bridge. Tromsø Airport connects the city to many destinations - in Europe. The city is warmer than most other places located on the same - latitude, due to the warming effect of the Gulf Stream.""", - """The city centre of Tromsø contains the highest number of old wooden - houses in Northern Norway, the oldest house dating from 1789. The Arctic - Cathedral, a modern church from 1965, is probably the most famous landmark - in Tromsø. The city is a cultural centre for its region, with several - festivals taking place in the summer. Some of Norway's best-known - musicians, Torbjørn Brundtland and Svein Berge of the electronica duo - Röyksopp and Lene Marlin grew up and started their careers in Tromsø. - Noted electronic musician Geir Jenssen also hails from Tromsø.""", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway. The administrative centre of the municipality is" - " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" - " alternative spellings of the city.Tromsø is considered the northernmost" - " city in the world with a population above 50,000. The most populous town" - " north of it is Alta, Norway, with a population of 14,272 (2013).", - ) - - desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. It is the largest urban area in Northern Norway and the" - " third largest north of the Arctic Circle (following Murmansk and Norilsk)." - " Most of Tromsø, including the city centre, is located on the island of" - " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - " Tromsøya had a population of 36,088. Substantial parts of the urban…", - ) - - def test_short_summarize(self): - example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - "The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - ) - - def test_small_then_large_summarize(self): - example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year." - " The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. The city centre of Tromsø contains the highest number" - " of old wooden houses in Northern Norway, the oldest house dating from" - " 1789. The Arctic Cathedral, a modern church from…", - ) - - -class CalcOgTestCase(unittest.TestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def test_simple(self): - html = b""" - - Foo - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_comment(self): - html = b""" - - Foo - - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_comment2(self): - html = b""" - - Foo - - Some text. - - Some more text. -

Text

- More text - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual( - og, - { - "og:title": "Foo", - "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", - }, - ) - - def test_script(self): - html = b""" - - Foo - - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_missing_title(self): - html = b""" - - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - - def test_h1_as_title(self): - html = b""" - - - -

Title

- - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - - def test_missing_title_and_broken_h1(self): - html = b""" - - -

- Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - - def test_empty(self): - """Test a body with no data in it.""" - html = b"" - tree = decode_body(html, "http://example.com/test.html") - self.assertIsNone(tree) - - def test_no_tree(self): - """A valid body with no tree in it.""" - html = b"\x00" - tree = decode_body(html, "http://example.com/test.html") - self.assertIsNone(tree) - - def test_xml(self): - """Test decoding XML and ensure it works properly.""" - # Note that the strip() call is important to ensure the xml tag starts - # at the initial byte. - html = b""" - - - - - FooSome text. - """.strip() - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_invalid_encoding(self): - """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" - html = b""" - - Foo - - Some text. - - - """ - tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_invalid_encoding2(self): - """A body which doesn't match the sent character encoding.""" - # Note that this contains an invalid UTF-8 sequence in the title. - html = b""" - - \xff\xff Foo - - Some text. - - - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - - def test_windows_1252(self): - """A body which uses cp1252, but doesn't declare that.""" - html = b""" - - \xf3 - - Some text. - - - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) - - -class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self): - """A character encoding is found via the meta tag.""" - encodings = _get_html_media_encodings( - b""" - - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - # A less well-formed version. - encodings = _get_html_media_encodings( - b""" - - < meta charset = ascii> - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_meta_charset_underscores(self): - """A character encoding contains underscore.""" - encodings = _get_html_media_encodings( - b""" - - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - - def test_xml_encoding(self): - """A character encoding is found via the meta tag.""" - encodings = _get_html_media_encodings( - b""" - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_meta_xml_encoding(self): - """Meta tags take precedence over XML encoding.""" - encodings = _get_html_media_encodings( - b""" - - - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - - def test_content_type(self): - """A character encoding is found via the Content-Type header.""" - # Test a few variations of the header. - headers = ( - 'text/html; charset="ascii";', - "text/html;charset=ascii;", - 'text/html; charset="ascii"', - "text/html; charset=ascii", - 'text/html; charset="ascii;', - 'text/html; charset=ascii";', - ) - for header in headers: - encodings = _get_html_media_encodings(b"", header) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_fallback(self): - """A character encoding cannot be found in the body or header.""" - encodings = _get_html_media_encodings(b"", "text/html") - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - def test_duplicates(self): - """Ensure each encoding is only attempted once.""" - encodings = _get_html_media_encodings( - b""" - - - - - - """, - 'text/html; charset="UTF_8"', - ) - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - def test_unknown_invalid(self): - """A character encoding should be ignored if it is unknown or invalid.""" - encodings = _get_html_media_encodings( - b""" - - - - - """, - 'text/html; charset="invalid"', - ) - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) -- cgit 1.5.1 From 0d6cfea9b867a14fa0fa885b04c8cbfdb4a7c4a9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 25 Jan 2022 13:06:29 +0100 Subject: Add admin API to reset connection timeouts for remote server (#11639) * Fix get federation status of destination if no error occured --- changelog.d/11639.feature | 1 + docs/usage/administration/admin_api/federation.md | 40 +++++++++++++++- synapse/federation/transport/server/__init__.py | 16 ++++--- synapse/federation/transport/server/_base.py | 14 +++--- synapse/federation/transport/server/federation.py | 24 +++++++--- .../federation/transport/server/groups_local.py | 8 ++-- .../federation/transport/server/groups_server.py | 8 ++-- synapse/rest/admin/__init__.py | 6 ++- synapse/rest/admin/federation.py | 44 ++++++++++++++++- tests/rest/admin/test_federation.py | 55 ++++++++++++++++++++-- 10 files changed, 183 insertions(+), 33 deletions(-) create mode 100644 changelog.d/11639.feature (limited to 'tests/rest') diff --git a/changelog.d/11639.feature b/changelog.d/11639.feature new file mode 100644 index 0000000000..e9f6704f7a --- /dev/null +++ b/changelog.d/11639.feature @@ -0,0 +1 @@ +Add admin API to reset connection timeouts for remote server. \ No newline at end of file diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md index 8f9535f57b..5e609561a6 100644 --- a/docs/usage/administration/admin_api/federation.md +++ b/docs/usage/administration/admin_api/federation.md @@ -86,7 +86,7 @@ The following fields are returned in the JSON response body: - `next_token`: string representing a positive integer - Indication for pagination. See above. - `total` - integer - Total number of destinations. -# Destination Details API +## Destination Details API This API gets the retry timing info for a specific remote server. @@ -108,7 +108,45 @@ A response body like the following is returned: } ``` +**Parameters** + +The following parameters should be set in the URL: + +- `destination` - Name of the remote server. + **Response** The response fields are the same like in the `destinations` array in [List of destinations](#list-of-destinations) response. + +## Reset connection timeout + +Synapse makes federation requests to other homeservers. If a federation request fails, +Synapse will mark the destination homeserver as offline, preventing any future requests +to that server for a "cooldown" period. This period grows over time if the server +continues to fail its responses +([exponential backoff](https://en.wikipedia.org/wiki/Exponential_backoff)). + +Admins can cancel the cooldown period with this API. + +This API resets the retry timing for a specific remote server and tries to connect to +the remote server again. It does not wait for the next `retry_interval`. +The connection must have previously run into an error and `retry_last_ts` +([Destination Details API](#destination-details-api)) must not be equal to `0`. + +The connection attempt is carried out in the background and can take a while +even if the API already returns the http status 200. + +The API is: + +``` +POST /_synapse/admin/v1/federation/destinations//reset_connection + +{} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `destination` - Name of the remote server. diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 77b936361a..db4fe2c798 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type from typing_extensions import Literal @@ -36,17 +36,19 @@ from synapse.http.servlet import ( parse_integer_from_args, parse_string_from_args, ) -from synapse.server import HomeServer from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.ratelimitutils import FederationRateLimiter +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): + def __init__(self, hs: "HomeServer", servlet_groups: Optional[List[str]] = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in @@ -113,7 +115,7 @@ class PublicRoomList(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -203,7 +205,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -251,7 +253,7 @@ class OpenIdUserInfo(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -297,7 +299,7 @@ DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { def register_servlets( - hs: HomeServer, + hs: "HomeServer", resource: HttpServer, authenticator: Authenticator, ratelimiter: FederationRateLimiter, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index da1fbf8b63..2ca7c05835 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -15,7 +15,7 @@ import functools import logging import re -from typing import Any, Awaitable, Callable, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, cast from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_V1_PREFIX @@ -29,11 +29,13 @@ from synapse.logging.opentracing import ( start_active_span_follows_from, whitelisted_homeserver, ) -from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import parse_and_validate_server_name +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -46,7 +48,7 @@ class NoAuthenticationError(AuthenticationError): class Authenticator: - def __init__(self, hs: HomeServer): + def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname @@ -114,11 +116,11 @@ class Authenticator: # alive retry_timings = await self.store.get_destination_retry_timings(origin) if retry_timings and retry_timings.retry_last_ts: - run_in_background(self._reset_retry_timings, origin) + run_in_background(self.reset_retry_timings, origin) return origin - async def _reset_retry_timings(self, origin: str) -> None: + async def reset_retry_timings(self, origin: str) -> None: try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) @@ -227,7 +229,7 @@ class BaseFederationServlet: def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index beadfa422b..9c1ad5851f 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -12,7 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from typing_extensions import Literal @@ -30,11 +40,13 @@ from synapse.http.servlet import ( parse_string_from_args, parse_strings_from_args, ) -from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) issue_8631_logger = logging.getLogger("synapse.8631_debug") @@ -47,7 +59,7 @@ class BaseFederationServerServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -596,7 +608,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -670,7 +682,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -706,7 +718,7 @@ class RoomComplexityServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py index a12cd18d58..496472e1dc 100644 --- a/synapse/federation/transport/server/groups_local.py +++ b/synapse/federation/transport/server/groups_local.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Tuple, Type from synapse.api.errors import SynapseError from synapse.federation.transport.server._base import ( @@ -19,10 +19,12 @@ from synapse.federation.transport.server._base import ( BaseFederationServlet, ) from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.server import HomeServer from synapse.types import JsonDict, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter +if TYPE_CHECKING: + from synapse.server import HomeServer + class BaseGroupsLocalServlet(BaseFederationServlet): """Abstract base class for federation servlet classes which provides a groups local handler. @@ -32,7 +34,7 @@ class BaseGroupsLocalServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py index b30e92a5eb..851b50152e 100644 --- a/synapse/federation/transport/server/groups_server.py +++ b/synapse/federation/transport/server/groups_server.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Tuple, Type from typing_extensions import Literal @@ -22,10 +22,12 @@ from synapse.federation.transport.server._base import ( BaseFederationServlet, ) from synapse.http.servlet import parse_string_from_args -from synapse.server import HomeServer from synapse.types import JsonDict, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter +if TYPE_CHECKING: + from synapse.server import HomeServer + class BaseGroupsServerServlet(BaseFederationServlet): """Abstract base class for federation servlet classes which provides a groups server handler. @@ -35,7 +37,7 @@ class BaseGroupsServerServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 465e06772b..b1e49d51b7 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -41,7 +41,8 @@ from synapse.rest.admin.event_reports import ( EventReportsRestServlet, ) from synapse.rest.admin.federation import ( - DestinationsRestServlet, + DestinationResetConnectionRestServlet, + DestinationRestServlet, ListDestinationsRestServlet, ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet @@ -267,7 +268,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) - DestinationsRestServlet(hs).register(http_server) + DestinationResetConnectionRestServlet(hs).register(http_server) + DestinationRestServlet(hs).register(http_server) ListDestinationsRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 8cd3fa189e..0f33f9e4da 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.federation.transport.server import Authenticator from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin @@ -90,7 +91,7 @@ class ListDestinationsRestServlet(RestServlet): return HTTPStatus.OK, response -class DestinationsRestServlet(RestServlet): +class DestinationRestServlet(RestServlet): """Get details of a destination. This needs user to have administrator access in Synapse. @@ -145,3 +146,44 @@ class DestinationsRestServlet(RestServlet): } return HTTPStatus.OK, response + + +class DestinationResetConnectionRestServlet(RestServlet): + """Reset destinations' connection timeouts and wake it up. + This needs user to have administrator access in Synapse. + + POST /_synapse/admin/v1/federation/destinations//reset_connection + {} + + returns: + 200 OK otherwise an error. + """ + + PATTERNS = admin_patterns( + "/federation/destinations/(?P[^/]+)/reset_connection$" + ) + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._authenticator = Authenticator(hs) + + async def on_POST( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not await self._store.is_destination_known(destination): + raise NotFoundError("Unknown destination") + + retry_timings = await self._store.get_destination_retry_timings(destination) + if not (retry_timings and retry_timings.retry_last_ts): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The retry timing does not need to be reset for this destination.", + ) + + # reset timings and wake up + await self._authenticator.reset_retry_timings(destination) + + return HTTPStatus.OK, {} diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index b70350b6f1..e2d3cff2a3 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -43,11 +43,15 @@ class FederationTestCase(unittest.HomeserverTestCase): @parameterized.expand( [ - ("/_synapse/admin/v1/federation/destinations",), - ("/_synapse/admin/v1/federation/destinations/dummy",), + ("GET", "/_synapse/admin/v1/federation/destinations"), + ("GET", "/_synapse/admin/v1/federation/destinations/dummy"), + ( + "POST", + "/_synapse/admin/v1/federation/destinations/dummy/reset_connection", + ), ] ) - def test_requester_is_no_admin(self, url: str) -> None: + def test_requester_is_no_admin(self, method: str, url: str) -> None: """ If the user is not a server admin, an error 403 is returned. """ @@ -56,7 +60,7 @@ class FederationTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") channel = self.make_request( - "GET", + method, url, content={}, access_token=other_user_tok, @@ -120,6 +124,16 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + # invalid destination + channel = self.make_request( + "POST", + self.url + "/dummy/reset_connection", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + def test_limit(self) -> None: """ Testing list of destinations with limit @@ -444,6 +458,39 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertIsNone(channel.json_body["failure_ts"]) self.assertIsNone(channel.json_body["last_successful_stream_ordering"]) + def test_destination_reset_connection(self) -> None: + """Reset timeouts and wake up destination.""" + self._create_destination("sub0.example.com", 100, 100, 100) + + channel = self.make_request( + "POST", + self.url + "/sub0.example.com/reset_connection", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + retry_timings = self.get_success( + self.store.get_destination_retry_timings("sub0.example.com") + ) + self.assertIsNone(retry_timings) + + def test_destination_reset_connection_not_required(self) -> None: + """Try to reset timeouts of a destination with no timeouts and get an error.""" + self._create_destination("sub0.example.com", None, 0, 0) + + channel = self.make_request( + "POST", + self.url + "/sub0.example.com/reset_connection", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "The retry timing does not need to be reset for this destination.", + channel.json_body["error"], + ) + def _create_destination( self, destination: str, -- cgit 1.5.1 From 6a72c910f180ee8b4bd78223775af48492769472 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 25 Jan 2022 17:11:40 +0100 Subject: Add admin API to get a list of federated rooms (#11658) --- changelog.d/11658.feature | 1 + docs/usage/administration/admin_api/federation.md | 60 +++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/federation.py | 56 ++++ synapse/storage/databases/main/transactions.py | 48 ++++ tests/rest/admin/test_federation.py | 302 ++++++++++++++++++++-- 6 files changed, 444 insertions(+), 25 deletions(-) create mode 100644 changelog.d/11658.feature (limited to 'tests/rest') diff --git a/changelog.d/11658.feature b/changelog.d/11658.feature new file mode 100644 index 0000000000..2ec9fb5eec --- /dev/null +++ b/changelog.d/11658.feature @@ -0,0 +1 @@ +Add an admin API to get a list of rooms that federate with a given remote homeserver. \ No newline at end of file diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md index 5e609561a6..60cbc5265e 100644 --- a/docs/usage/administration/admin_api/federation.md +++ b/docs/usage/administration/admin_api/federation.md @@ -119,6 +119,66 @@ The following parameters should be set in the URL: The response fields are the same like in the `destinations` array in [List of destinations](#list-of-destinations) response. +## Destination rooms + +This API gets the rooms that federate with a specific remote server. + +The API is: + +``` +GET /_synapse/admin/v1/federation/destinations//rooms +``` + +A response body like the following is returned: + +```json +{ + "rooms":[ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "stream_ordering": 8326 + }, + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "stream_ordering": 93534 + } + ], + "total": 2 +} +``` + +To paginate, check for `next_token` and if present, call the endpoint again +with `from` set to the value of `next_token`. This will return a new page. + +If the endpoint does not return a `next_token` then there are no more destinations +to paginate through. + +**Parameters** + +The following parameters should be set in the URL: + +- `destination` - Name of the remote server. + +The following query parameters are available: + +- `from` - Offset in the returned list. Defaults to `0`. +- `limit` - Maximum amount of destinations to return. Defaults to `100`. +- `dir` - Direction of room order by `room_id`. Either `f` for forwards or `b` for + backwards. Defaults to `f`. + +**Response** + +The following fields are returned in the JSON response body: + +- `rooms` - An array of objects, each containing information about a room. + Room objects contain the following fields: + - `room_id` - string - The ID of the room. + - `stream_ordering` - integer - The stream ordering of the most recent + successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation) + to this destination in this room. +- `next_token`: string representing a positive integer - Indication for pagination. See above. +- `total` - integer - Total number of destinations. + ## Reset connection timeout Synapse makes federation requests to other homeservers. If a federation request fails, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b1e49d51b7..9be9e33c8e 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -41,6 +41,7 @@ from synapse.rest.admin.event_reports import ( EventReportsRestServlet, ) from synapse.rest.admin.federation import ( + DestinationMembershipRestServlet, DestinationResetConnectionRestServlet, DestinationRestServlet, ListDestinationsRestServlet, @@ -268,6 +269,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + DestinationMembershipRestServlet(hs).register(http_server) DestinationResetConnectionRestServlet(hs).register(http_server) DestinationRestServlet(hs).register(http_server) ListDestinationsRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 0f33f9e4da..d162e0081e 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -148,6 +148,62 @@ class DestinationRestServlet(RestServlet): return HTTPStatus.OK, response +class DestinationMembershipRestServlet(RestServlet): + """Get list of rooms of a destination. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations//rooms?from=0&limit=10 + + returns: + 200 OK with a list of rooms if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + """ + + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)/rooms$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not await self._store.is_destination_known(destination): + raise NotFoundError("Unknown destination") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + + rooms, total = await self._store.get_destination_rooms_paginate( + destination, start, limit, direction + ) + response = {"rooms": rooms, "total": total} + if (start + limit) < total: + response["next_token"] = str(start + len(rooms)) + + return HTTPStatus.OK, response + + class DestinationResetConnectionRestServlet(RestServlet): """Reset destinations' connection timeouts and wake it up. This needs user to have administrator access in Synapse. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 4b78b4d098..ba79e19f7f 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -561,6 +561,54 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): "get_destinations_paginate_txn", get_destinations_paginate_txn ) + async def get_destination_rooms_paginate( + self, destination: str, start: int, limit: int, direction: str = "f" + ) -> Tuple[List[JsonDict], int]: + """Function to retrieve a paginated list of destination's rooms. + This will return a json list of rooms and the + total number of rooms. + + Args: + destination: the destination to query + start: start number to begin the query from + limit: number of rows to retrieve + direction: sort ascending or descending by room_id + Returns: + A tuple of a dict of rooms and a count of total rooms. + """ + + def get_destination_rooms_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT COUNT(*) as total_rooms + FROM destination_rooms + WHERE destination = ? + """ + txn.execute(sql, [destination]) + count = cast(Tuple[int], txn.fetchone())[0] + + rooms = self.db_pool.simple_select_list_paginate_txn( + txn=txn, + table="destination_rooms", + orderby="room_id", + start=start, + limit=limit, + retcols=("room_id", "stream_ordering"), + order_direction=order, + ) + return rooms, count + + return await self.db_pool.runInteraction( + "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn + ) + async def is_destination_known(self, destination: str) -> bool: """Check if a destination is known to the server.""" result = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index e2d3cff2a3..71068d16cd 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -20,7 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client import login +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -52,9 +52,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] ) def test_requester_is_no_admin(self, method: str, url: str) -> None: - """ - If the user is not a server admin, an error 403 is returned. - """ + """If the user is not a server admin, an error 403 is returned.""" self.register_user("user", "pass", admin=False) other_user_tok = self.login("user", "pass") @@ -70,9 +68,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: - """ - If parameters are invalid, an error is returned. - """ + """If parameters are invalid, an error is returned.""" # negative limit channel = self.make_request( @@ -135,9 +131,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_limit(self) -> None: - """ - Testing list of destinations with limit - """ + """Testing list of destinations with limit""" number_destinations = 20 self._create_destinations(number_destinations) @@ -155,9 +149,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_from(self) -> None: - """ - Testing list of destinations with a defined starting point (from) - """ + """Testing list of destinations with a defined starting point (from)""" number_destinations = 20 self._create_destinations(number_destinations) @@ -175,9 +167,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_limit_and_from(self) -> None: - """ - Testing list of destinations with a defined starting point and limit - """ + """Testing list of destinations with a defined starting point and limit""" number_destinations = 20 self._create_destinations(number_destinations) @@ -195,9 +185,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_next_token(self) -> None: - """ - Testing that `next_token` appears at the right place - """ + """Testing that `next_token` appears at the right place""" number_destinations = 20 self._create_destinations(number_destinations) @@ -256,9 +244,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) def test_list_all_destinations(self) -> None: - """ - List all destinations. - """ + """List all destinations.""" number_destinations = 5 self._create_destinations(number_destinations) @@ -277,9 +263,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_order_by(self) -> None: - """ - Testing order list with parameter `order_by` - """ + """Testing order list with parameter `order_by`""" def _order_test( expected_destination_list: List[str], @@ -543,3 +527,271 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertIn("retry_interval", c) self.assertIn("failure_ts", c) self.assertIn("last_successful_stream_ordering", c) + + +class DestinationMembershipTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.dest = "sub0.example.com" + self.url = f"/_synapse/admin/v1/federation/destinations/{self.dest}/rooms" + + # Record that we successfully contacted a destination in the DB. + self.get_success( + self.store.set_destination_retry_timings(self.dest, None, 0, 0) + ) + + def test_requester_is_no_admin(self) -> None: + """If the user is not a server admin, an error 403 is returned.""" + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + "GET", + self.url, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self) -> None: + """If parameters are invalid, an error is returned.""" + + # negative limit + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid destination + channel = self.make_request( + "GET", + "/_synapse/admin/v1/federation/destinations/%s/rooms" % ("invalid",), + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_limit(self) -> None: + """Testing list of destinations with limit""" + + number_rooms = 5 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url + "?limit=3", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 3) + self.assertEqual(channel.json_body["next_token"], "3") + self._check_fields(channel.json_body["rooms"]) + + def test_from(self) -> None: + """Testing list of rooms with a defined starting point (from)""" + + number_rooms = 10 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 5) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["rooms"]) + + def test_limit_and_from(self) -> None: + """Testing list of rooms with a defined starting point and limit""" + + number_rooms = 10 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url + "?from=3&limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(channel.json_body["next_token"], "8") + self.assertEqual(len(channel.json_body["rooms"]), 5) + self._check_fields(channel.json_body["rooms"]) + + def test_order_direction(self) -> None: + """Testing order list with parameter `dir`""" + number_rooms = 4 + self._create_destination_rooms(number_rooms) + + # get list in forward direction + channel_asc = self.make_request( + "GET", + self.url + "?dir=f", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body) + self.assertEqual(channel_asc.json_body["total"], number_rooms) + self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"])) + self._check_fields(channel_asc.json_body["rooms"]) + + # get list in backward direction + channel_desc = self.make_request( + "GET", + self.url + "?dir=b", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body) + self.assertEqual(channel_desc.json_body["total"], number_rooms) + self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"])) + self._check_fields(channel_desc.json_body["rooms"]) + + # test that both lists have different directions + for i in range(0, number_rooms): + self.assertEqual( + channel_asc.json_body["rooms"][i]["room_id"], + channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"], + ) + + def test_next_token(self) -> None: + """Testing that `next_token` appears at the right place""" + + number_rooms = 5 + self._create_destination_rooms(number_rooms) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), number_rooms) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=6", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), number_rooms) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=4", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 4) + self.assertEqual(channel.json_body["next_token"], "4") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=4", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_destination_rooms(self) -> None: + """Testing that request the list of rooms is successfully.""" + number_rooms = 3 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(number_rooms, len(channel.json_body["rooms"])) + self._check_fields(channel.json_body["rooms"]) + + def _create_destination_rooms(self, number_rooms: int) -> None: + """Create a number rooms for destination + + Args: + number_rooms: Number of rooms to be created + """ + for _ in range(0, number_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + self.get_success( + self.store.store_destination_rooms_entries((self.dest,), room_id, 1234) + ) + + def _check_fields(self, content: List[JsonDict]) -> None: + """Checks that the expected room attributes are present in content + + Args: + content: List that is checked for content + """ + for c in content: + self.assertIn("room_id", c) + self.assertIn("stream_ordering", c) -- cgit 1.5.1 From 95b3f952fa43e51feae166fa1678761c5e32d900 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 26 Jan 2022 12:02:54 +0000 Subject: Add a config flag to inhibit `M_USER_IN_USE` during registration (#11743) This is mostly motivated by the tchap use case, where usernames are automatically generated from the user's email address (in a way that allows figuring out the email address from the username). Therefore, it's an issue if we respond to requests on /register and /register/available with M_USER_IN_USE, because it can potentially leak email addresses (which include the user's real name and place of work). This commit adds a flag to inhibit the M_USER_IN_USE errors that are raised both by /register/available, and when providing a username early into the registration process. This error will still be raised if the user completes the registration process but the username conflicts. This is particularly useful when using modules (https://github.com/matrix-org/synapse/pull/11790 adds a module callback to set the username of users at registration) or SSO, since they can ensure the username is unique. More context is available in the PR that introduced this behaviour to synapse-dinsic: matrix-org/synapse-dinsic#48 - as well as the issue in the matrix-dinsic repo: matrix-org/matrix-dinsic#476 --- changelog.d/11743.feature | 1 + docs/sample_config.yaml | 10 ++++++++++ synapse/config/registration.py | 12 +++++++++++ synapse/handlers/register.py | 26 +++++++++++++----------- synapse/rest/client/register.py | 11 ++++++++++ tests/rest/client/test_register.py | 41 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11743.feature (limited to 'tests/rest') diff --git a/changelog.d/11743.feature b/changelog.d/11743.feature new file mode 100644 index 0000000000..9809f48b96 --- /dev/null +++ b/changelog.d/11743.feature @@ -0,0 +1 @@ +Add a config flag to inhibit M_USER_IN_USE during registration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1b86d0295d..b38e6d6c88 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1428,6 +1428,16 @@ account_threepid_delegates: # #auto_join_rooms_for_guests: false +# Whether to inhibit errors raised when registering a new account if the user ID +# already exists. If turned on, that requests to /register/available will always +# show a user ID as available, and Synapse won't raise an error when starting +# a registration with a user ID that already exists. However, Synapse will still +# raise an error if the registration completes and the username conflicts. +# +# Defaults to false. +# +#inhibit_user_in_use_error: true + ## Metrics ### diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7a059c6dec..ea9b50fe97 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -190,6 +190,8 @@ class RegistrationConfig(Config): # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") + self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False) + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -446,6 +448,16 @@ class RegistrationConfig(Config): # Defaults to true. # #auto_join_rooms_for_guests: false + + # Whether to inhibit errors raised when registering a new account if the user ID + # already exists. If turned on, that requests to /register/available will always + # show a user ID as available, and Synapse won't raise an error when starting + # a registration with a user ID that already exists. However, Synapse will still + # raise an error if the registration completes and the username conflicts. + # + # Defaults to false. + # + #inhibit_user_in_use_error: true """ % locals() ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f08a516a75..a719d5eef3 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -132,6 +132,7 @@ class RegistrationHandler: localpart: str, guest_access_token: Optional[str] = None, assigned_user_id: Optional[str] = None, + inhibit_user_in_use_error: bool = False, ) -> None: if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -171,21 +172,22 @@ class RegistrationHandler: users = await self.store.get_users_by_id_case_insensitive(user_id) if users: - if not guest_access_token: + if not inhibit_user_in_use_error and not guest_access_token: raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - user_data = await self.auth.get_user_by_access_token(guest_access_token) - if ( - not user_data.is_guest - or UserID.from_string(user_data.user_id).localpart != localpart - ): - raise AuthError( - 403, - "Cannot register taken user ID without valid guest " - "credentials for that user.", - errcode=Codes.FORBIDDEN, - ) + if guest_access_token: + user_data = await self.auth.get_user_by_access_token(guest_access_token) + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): + raise AuthError( + 403, + "Cannot register taken user ID without valid guest " + "credentials for that user.", + errcode=Codes.FORBIDDEN, + ) if guest_access_token is None: try: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 8b56c76aed..c59dae7c03 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -339,12 +339,19 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) + self.inhibit_user_in_use_error = ( + hs.config.registration.inhibit_user_in_use_error + ) + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) + if self.inhibit_user_in_use_error: + return 200, {"available": True} + ip = request.getClientIP() with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -422,6 +429,9 @@ class RegisterRestServlet(RestServlet): self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None ) + self._inhibit_user_in_use_error = ( + hs.config.registration.inhibit_user_in_use_error + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -564,6 +574,7 @@ class RegisterRestServlet(RestServlet): desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, + inhibit_user_in_use_error=self._inhibit_user_in_use_error, ) # Check if the user-interactive authentication flows are complete, if diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 6e7c0f11df..407dd32a73 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -726,6 +726,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) + @override_config( + { + "inhibit_user_in_use_error": True, + } + ) + def test_inhibit_user_in_use_error(self): + """Tests that the 'inhibit_user_in_use_error' configuration flag behaves + correctly. + """ + username = "arthur" + + # Manually register the user, so we know the test isn't passing because of a lack + # of clashing. + reg_handler = self.hs.get_registration_handler() + self.get_success(reg_handler.register_user(username)) + + # Check that /available correctly ignores the username provided despite the + # username being already registered. + channel = self.make_request("GET", "register/available?username=" + username) + self.assertEquals(200, channel.code, channel.result) + + # Test that when starting a UIA registration flow the request doesn't fail because + # of a conflicting username + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "foo"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Test that finishing the registration fails because of a conflicting username. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + class AccountValidityTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 2897fb6b4fb8bdaea0e919233d5ccaf5dea12742 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Jan 2022 08:27:04 -0500 Subject: Improvements to bundling aggregations. (#11815) This is some odds and ends found during the review of #11791 and while continuing to work in this code: * Return attrs classes instead of dictionaries from some methods to improve type safety. * Call `get_bundled_aggregations` fewer times. * Adds a missing assertion in the tests. * Do not return empty bundled aggregations for an event (preferring to not include the bundle at all, as the docstring states). --- changelog.d/11815.misc | 1 + synapse/events/utils.py | 57 ++++++++++++++------- synapse/handlers/room.py | 77 +++++++++++++++-------------- synapse/handlers/search.py | 45 ++++++++--------- synapse/handlers/sync.py | 3 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 39 +++++++++------ synapse/rest/client/room.py | 39 +++++++++------ synapse/rest/client/sync.py | 3 +- synapse/storage/databases/main/relations.py | 61 ++++++++++++++--------- synapse/storage/databases/main/stream.py | 22 ++++++--- tests/rest/client/test_relations.py | 2 +- 12 files changed, 212 insertions(+), 139 deletions(-) create mode 100644 changelog.d/11815.misc (limited to 'tests/rest') diff --git a/changelog.d/11815.misc b/changelog.d/11815.misc new file mode 100644 index 0000000000..83aa6d6eb0 --- /dev/null +++ b/changelog.d/11815.misc @@ -0,0 +1 @@ +Improve type safety of bundled aggregations code. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 918adeecf8..243696b357 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import re -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict @@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze from . import EventBase +if TYPE_CHECKING: + from synapse.storage.databases.main.relations import BundledAggregations + + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? JsonDict: """Serializes a single event. @@ -415,7 +429,7 @@ class EventClientSerializer: self, event: EventBase, time_now: int, - aggregations: JsonDict, + aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -427,13 +441,18 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ - # Make a copy in-case the object is cached. - aggregations = aggregations.copy() + serialized_aggregations = {} + + if aggregations.annotations: + serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations + + if aggregations.references: + serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references - if RelationTypes.REPLACE in aggregations: + if aggregations.replace: # If there is an edit replace the content, preserving existing # relations. - edit = aggregations[RelationTypes.REPLACE] + edit = aggregations.replace # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -451,24 +470,28 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - aggregations[RelationTypes.REPLACE] = { + serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, } # If this event is the start of a thread, include a summary of the replies. - if RelationTypes.THREAD in aggregations: - # Serialize the latest thread event. - latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] - - # Don't bundle aggregations as this could recurse forever. - aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=None - ) + if aggregations.thread: + serialized_aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": self.serialize_event( + aggregations.thread.latest_event, time_now, bundle_aggregations=None + ), + "count": aggregations.thread.count, + "current_user_participated": aggregations.thread.current_user_participated, + } # Include the bundled aggregations in the event. - serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + if serialized_aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + serialized_aggregations + ) def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f963078e59..1420d67729 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -30,6 +30,7 @@ from typing import ( Tuple, ) +import attr from typing_extensions import TypedDict from synapse.api.constants import ( @@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -90,6 +92,17 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventContext: + events_before: List[EventBase] + event: EventBase + events_after: List[EventBase] + state: List[EventBase] + aggregations: Dict[str, BundledAggregations] + start: str + end: str + + class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1119,7 +1132,7 @@ class RoomContextHandler: limit: int, event_filter: Optional[Filter], use_admin_priviledge: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[EventContext]: """Retrieves events, pagination tokens and state around a given event in a room. @@ -1167,38 +1180,28 @@ class RoomContextHandler: results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) + events_before = results.events_before + events_after = results.events_after if event_filter: - results["events_before"] = await event_filter.filter( - results["events_before"] - ) - results["events_after"] = await event_filter.filter(results["events_after"]) + events_before = await event_filter.filter(events_before) + events_after = await event_filter.filter(events_after) - results["events_before"] = await filter_evts(results["events_before"]) - results["events_after"] = await filter_evts(results["events_after"]) + events_before = await filter_evts(events_before) + events_after = await filter_evts(events_after) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. - results["event"] = filtered[0] + event = filtered[0] # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - [results["event"]], user.to_string() + itertools.chain(events_before, (event,), events_after), + user.to_string(), ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_before"], user.to_string() - ) - ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_after"], user.to_string() - ) - ) - results["aggregations"] = aggregations - if results["events_after"]: - last_event_id = results["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event_id @@ -1206,9 +1209,9 @@ class RoomContextHandler: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], + events_before, + (event,), + events_after, ) ) else: @@ -1226,21 +1229,23 @@ class RoomContextHandler: if event_filter: state_events = await event_filter.filter(state_events) - results["state"] = await filter_evts(state_events) - # We use a dummy token here as we only care about the room portion of # the token, which we replace. token = StreamToken.START - results["start"] = await token.copy_and_replace( - "room_key", results["start"] - ).to_string(self.store) - - results["end"] = await token.copy_and_replace( - "room_key", results["end"] - ).to_string(self.store) - - return results + return EventContext( + events_before=events_before, + event=event, + events_after=events_after, + state=await filter_evts(state_events), + aggregations=aggregations, + start=await token.copy_and_replace("room_key", results.start).to_string( + self.store + ), + end=await token.copy_and_replace("room_key", results.end).to_string( + self.store + ), + ) class TimestampLookupHandler: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0b153a6822..02bb5ae72f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -361,36 +361,37 @@ class SearchHandler: logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), - len(res["events_after"]), + len(res.events_before), + len(res.events_after), ) - res["events_before"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_before"] + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before ) - res["events_after"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_after"] + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after ) - res["start"] = await now_token.copy_and_replace( - "room_key", res["start"] - ).to_string(self.store) - - res["end"] = await now_token.copy_and_replace( - "room_key", res["end"] - ).to_string(self.store) + context = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace( + "room_key", res.end + ).to_string(self.store), + } if include_profile: senders = { ev.sender - for ev in itertools.chain( - res["events_before"], [event], res["events_after"] - ) + for ev in itertools.chain(events_before, [event], events_after) } - if res["events_after"]: - last_event_id = res["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event.event_id @@ -402,7 +403,7 @@ class SearchHandler: last_event_id, state_filter ) - res["profile_info"] = { + context["profile_info"] = { s.state_key: { "displayname": s.content.get("displayname", None), "avatar_url": s.content.get("avatar_url", None), @@ -411,7 +412,7 @@ class SearchHandler: if s.type == EventTypes.Member and s.state_key in senders } - contexts[event.event_id] = res + contexts[event.event_id] = context else: contexts = {} @@ -421,10 +422,10 @@ class SearchHandler: for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now + context["events_before"], time_now # type: ignore[arg-type] ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now + context["events_after"], time_now # type: ignore[arg-type] ) state_results = {} diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7e2a892b63..c72ed7c290 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -100,7 +101,7 @@ class TimelineBatch: limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. - bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None + bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index dadfc57413..3df8452eec 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -455,7 +455,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results["events_before"] + self.storage, user_id, results.events_before ) the_events.append(notif_event) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index efe25fe7eb..5b706efbcf 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, @@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet): use_admin_priviledge=True, ) - if not results: + if not event_context: raise SynapseError( HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND ) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return HTTPStatus.OK, results diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90bb9142a0..90355e44b2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, limit, event_filter ) - if not results: + if not event_context: raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d20ae1421e..f9615da525 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace +from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet): def serialize( events: Iterable[EventBase], - aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + aggregations: Optional[Dict[str, BundledAggregations]] = None, ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2cb5d06c13..a9a5dd5f03 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,17 +13,7 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast import attr from frozendict import frozendict @@ -43,6 +33,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -51,6 +42,30 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + latest_event: EventBase + count: int + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore): async def _get_bundled_aggregation_for_event( self, event: EventBase, user_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[BundledAggregations]: """Generate bundled aggregations for an event. Note that this does not use a cache, but depends on cached methods. @@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore): # The bundled aggregations to include, a mapping of relation type to a # type-specific value. Some types include the direct return type here # while others need more processing during serialization. - aggregations: Dict[str, Any] = {} + aggregations = BundledAggregations() annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations.annotations = annotations.to_dict() references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() + aggregations.references = references.to_dict() edit = None if event.type == EventTypes.Message: edit = await self.get_applicable_edit(event_id, room_id) if edit: - aggregations[RelationTypes.REPLACE] = edit + aggregations.replace = edit # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: @@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore): event_id, room_id, user_id ) if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - "latest_event": latest_thread_event, - "count": thread_count, - "current_user_participated": participated, - } + aggregations.thread = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + current_user_participated=participated, + ) # Store the bundled aggregations in the event metadata for later use. return aggregations @@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore): self, events: Iterable[EventBase], user_id: str, - ) -> Dict[str, Dict[str, Any]]: + ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. Args: @@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore): results = {} for event in events: event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result is not None: + if event_result: results[event.event_id] = event_result return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 319464b1fa..a898f847e7 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -81,6 +81,14 @@ class _EventDictReturn: stream_ordering: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventsAround: + events_before: List[EventBase] + events_after: List[EventBase] + start: RoomStreamToken + end: RoomStreamToken + + def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], @@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, - ) -> dict: + ) -> _EventsAround: """Retrieve events and pagination tokens around a given event in a room. """ @@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): list(results["after"]["event_ids"]), get_prev_content=True ) - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } + return _EventsAround( + events_before=events_before, + events_after=events_after, + start=results["before"]["token"], + end=results["after"]["token"], + ) def _get_events_around_txn( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c9b220e73d..96ae7790bb 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - self._find_event_in_chunk(room_timeline["events"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included -- cgit 1.5.1 From bf60da1a60096fac5fb778b732ff2214862ac808 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 28 Jan 2022 14:41:33 +0000 Subject: Configurable limits on avatars (#11846) Only allow files which file size and content types match configured limits to be set as avatar. Most of the inspiration from the non-test code comes from matrix-org/synapse-dinsic#19 --- changelog.d/11846.feature | 1 + docs/sample_config.yaml | 14 ++++ synapse/config/server.py | 27 +++++++ synapse/handlers/profile.py | 67 ++++++++++++++++ synapse/handlers/room_member.py | 6 ++ tests/handlers/test_profile.py | 94 ++++++++++++++++++++++- tests/rest/client/test_profile.py | 156 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 363 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11846.feature (limited to 'tests/rest') diff --git a/changelog.d/11846.feature b/changelog.d/11846.feature new file mode 100644 index 0000000000..fcf6affdb5 --- /dev/null +++ b/changelog.d/11846.feature @@ -0,0 +1 @@ +Allow configuring a maximum file size as well as a list of allowed content types for avatars. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index abf28e4490..689b207fc0 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -471,6 +471,20 @@ limit_remote_rooms: # #allow_per_room_profiles: false +# The largest allowed file size for a user avatar. Defaults to no restriction. +# +# Note that user avatar changes will not work if this is set without +# using Synapse's media repository. +# +#max_avatar_size: 10M + +# The MIME types allowed for user avatars. Defaults to no restriction. +# +# Note that user avatar changes will not work if this is set without +# using Synapse's media repository. +# +#allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"] + # How long to keep redacted events in unredacted form in the database. After # this period redacted events get replaced with their redacted form in the DB. # diff --git a/synapse/config/server.py b/synapse/config/server.py index f200d0c1f1..a460cf25b4 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -489,6 +489,19 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + # The maximum size an avatar can have, in bytes. + self.max_avatar_size = config.get("max_avatar_size") + if self.max_avatar_size is not None: + self.max_avatar_size = self.parse_size(self.max_avatar_size) + + # The MIME types allowed for an avatar. + self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes") + if self.allowed_avatar_mimetypes and not isinstance( + self.allowed_avatar_mimetypes, + list, + ): + raise ConfigError("allowed_avatar_mimetypes must be a list") + self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])] # no_tls is not really supported any more, but let's grandfather it in @@ -1168,6 +1181,20 @@ class ServerConfig(Config): # #allow_per_room_profiles: false + # The largest allowed file size for a user avatar. Defaults to no restriction. + # + # Note that user avatar changes will not work if this is set without + # using Synapse's media repository. + # + #max_avatar_size: 10M + + # The MIME types allowed for user avatars. Defaults to no restriction. + # + # Note that user avatar changes will not work if this is set without + # using Synapse's media repository. + # + #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"] + # How long to keep redacted events in unredacted form in the database. After # this period redacted events get replaced with their redacted form in the DB. # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6b5a6ded8b..36e3ad2ba9 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -31,6 +31,8 @@ from synapse.types import ( create_requester, get_domain_from_id, ) +from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import parse_and_validate_mxc_uri if TYPE_CHECKING: from synapse.server import HomeServer @@ -64,6 +66,11 @@ class ProfileHandler: self.user_directory_handler = hs.get_user_directory_handler() self.request_ratelimiter = hs.get_request_ratelimiter() + self.max_avatar_size = hs.config.server.max_avatar_size + self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes + + self.server_name = hs.config.server.server_name + if hs.config.worker.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS @@ -286,6 +293,9 @@ class ProfileHandler: 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) + if not await self.check_avatar_size_and_mime_type(new_avatar_url): + raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN) + avatar_url_to_set: Optional[str] = new_avatar_url if new_avatar_url == "": avatar_url_to_set = None @@ -307,6 +317,63 @@ class ProfileHandler: await self._update_join_states(requester, target_user) + @cached() + async def check_avatar_size_and_mime_type(self, mxc: str) -> bool: + """Check that the size and content type of the avatar at the given MXC URI are + within the configured limits. + + Args: + mxc: The MXC URI at which the avatar can be found. + + Returns: + A boolean indicating whether the file can be allowed to be set as an avatar. + """ + if not self.max_avatar_size and not self.allowed_avatar_mimetypes: + return True + + server_name, _, media_id = parse_and_validate_mxc_uri(mxc) + + if server_name == self.server_name: + media_info = await self.store.get_local_media(media_id) + else: + media_info = await self.store.get_cached_remote_media(server_name, media_id) + + if media_info is None: + # Both configuration options need to access the file's metadata, and + # retrieving remote avatars just for this becomes a bit of a faff, especially + # if e.g. the file is too big. It's also generally safe to assume most files + # used as avatar are uploaded locally, or if the upload didn't happen as part + # of a PUT request on /avatar_url that the file was at least previewed by the + # user locally (and therefore downloaded to the remote media cache). + logger.warning("Forbidding avatar change to %s: avatar not on server", mxc) + return False + + if self.max_avatar_size: + # Ensure avatar does not exceed max allowed avatar size + if media_info["media_length"] > self.max_avatar_size: + logger.warning( + "Forbidding avatar change to %s: %d bytes is above the allowed size " + "limit", + mxc, + media_info["media_length"], + ) + return False + + if self.allowed_avatar_mimetypes: + # Ensure the avatar's file type is allowed + if ( + self.allowed_avatar_mimetypes + and media_info["media_type"] not in self.allowed_avatar_mimetypes + ): + logger.warning( + "Forbidding avatar change to %s: mimetype %s not allowed", + mxc, + media_info["media_type"], + ) + return False + + return True + async def on_profile_query(self, args: JsonDict) -> JsonDict: """Handles federation profile query requests.""" diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6aa910dd10..3dd5e1b6e4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -590,6 +590,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): errcode=Codes.BAD_JSON, ) + if "avatar_url" in content: + if not await self.profile_handler.check_avatar_size_and_mime_type( + content["avatar_url"], + ): + raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN) + # The event content should *not* include the authorising user as # it won't be properly signed. Strip it out since it might come # back from a client updating a display name / avatar. diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index c153018fd8..60235e5699 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Any, Dict from unittest.mock import Mock import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin +from synapse.server import HomeServer from synapse.types import UserID from tests import unittest @@ -46,7 +47,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: HomeServer): self.store = hs.get_datastore() self.frank = UserID.from_string("@1234abcd:test") @@ -248,3 +249,92 @@ class ProfileTestCase(unittest.HomeserverTestCase): ), SynapseError, ) + + def test_avatar_constraints_no_config(self): + """Tests that the method to check an avatar against configured constraints skips + all of its check if no constraint is configured. + """ + # The first check that's done by this method is whether the file exists; if we + # don't get an error on a non-existing file then it means all of the checks were + # successfully skipped. + res = self.get_success( + self.handler.check_avatar_size_and_mime_type("mxc://test/unknown_file") + ) + self.assertTrue(res) + + @unittest.override_config({"max_avatar_size": 50}) + def test_avatar_constraints_missing(self): + """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't + be found. + """ + res = self.get_success( + self.handler.check_avatar_size_and_mime_type("mxc://test/unknown_file") + ) + self.assertFalse(res) + + @unittest.override_config({"max_avatar_size": 50}) + def test_avatar_constraints_file_size(self): + """Tests that a file that's above the allowed file size is forbidden but one + that's below it is allowed. + """ + self._setup_local_files( + { + "small": {"size": 40}, + "big": {"size": 60}, + } + ) + + res = self.get_success( + self.handler.check_avatar_size_and_mime_type("mxc://test/small") + ) + self.assertTrue(res) + + res = self.get_success( + self.handler.check_avatar_size_and_mime_type("mxc://test/big") + ) + self.assertFalse(res) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_avatar_constraint_mime_type(self): + """Tests that a file with an unauthorised MIME type is forbidden but one with + an authorised content type is allowed. + """ + self._setup_local_files( + { + "good": {"mimetype": "image/png"}, + "bad": {"mimetype": "application/octet-stream"}, + } + ) + + res = self.get_success( + self.handler.check_avatar_size_and_mime_type("mxc://test/good") + ) + self.assertTrue(res) + + res = self.get_success( + self.handler.check_avatar_size_and_mime_type("mxc://test/bad") + ) + self.assertFalse(res) + + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + """Stores metadata about files in the database. + + Args: + names_and_props: A dictionary with one entry per file, with the key being the + file's name, and the value being a dictionary of properties. Supported + properties are "mimetype" (for the file's type) and "size" (for the + file's size). + """ + store = self.hs.get_datastore() + + for name, props in names_and_props.items(): + self.get_success( + store.store_local_media( + media_id=name, + media_type=props.get("mimetype", "image/png"), + time_now_ms=self.clock.time_msec(), + upload_name=None, + media_length=props.get("size", 50), + user_id=UserID.from_string("@rin:test"), + ) + ) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 2860579c2e..ead883ded8 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -13,8 +13,12 @@ # limitations under the License. """Tests REST events for /profile paths.""" +from typing import Any, Dict + +from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room +from synapse.types import UserID from tests import unittest @@ -25,6 +29,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): admin.register_servlets_for_client_rest_resource, login.register_servlets, profile.register_servlets, + room.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -150,6 +155,157 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) return channel.json_body.get("avatar_url") + @unittest.override_config({"max_avatar_size": 50}) + def test_avatar_size_limit_global(self): + """Tests that the maximum size limit for avatars is enforced when updating a + global profile. + """ + self._setup_local_files( + { + "small": {"size": 40}, + "big": {"size": 60}, + } + ) + + channel = self.make_request( + "PUT", + f"/profile/{self.owner}/avatar_url", + content={"avatar_url": "mxc://test/big"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 403, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body + ) + + channel = self.make_request( + "PUT", + f"/profile/{self.owner}/avatar_url", + content={"avatar_url": "mxc://test/small"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + @unittest.override_config({"max_avatar_size": 50}) + def test_avatar_size_limit_per_room(self): + """Tests that the maximum size limit for avatars is enforced when updating a + per-room profile. + """ + self._setup_local_files( + { + "small": {"size": 40}, + "big": {"size": 60}, + } + ) + + room_id = self.helper.create_room_as(tok=self.owner_tok) + + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.member/{self.owner}", + content={"membership": "join", "avatar_url": "mxc://test/big"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 403, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body + ) + + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.member/{self.owner}", + content={"membership": "join", "avatar_url": "mxc://test/small"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_avatar_allowed_mime_type_global(self): + """Tests that the MIME type whitelist for avatars is enforced when updating a + global profile. + """ + self._setup_local_files( + { + "good": {"mimetype": "image/png"}, + "bad": {"mimetype": "application/octet-stream"}, + } + ) + + channel = self.make_request( + "PUT", + f"/profile/{self.owner}/avatar_url", + content={"avatar_url": "mxc://test/bad"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 403, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body + ) + + channel = self.make_request( + "PUT", + f"/profile/{self.owner}/avatar_url", + content={"avatar_url": "mxc://test/good"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_avatar_allowed_mime_type_per_room(self): + """Tests that the MIME type whitelist for avatars is enforced when updating a + per-room profile. + """ + self._setup_local_files( + { + "good": {"mimetype": "image/png"}, + "bad": {"mimetype": "application/octet-stream"}, + } + ) + + room_id = self.helper.create_room_as(tok=self.owner_tok) + + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.member/{self.owner}", + content={"membership": "join", "avatar_url": "mxc://test/bad"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 403, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body + ) + + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.member/{self.owner}", + content={"membership": "join", "avatar_url": "mxc://test/good"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + """Stores metadata about files in the database. + + Args: + names_and_props: A dictionary with one entry per file, with the key being the + file's name, and the value being a dictionary of properties. Supported + properties are "mimetype" (for the file's type) and "size" (for the + file's size). + """ + store = self.hs.get_datastore() + + for name, props in names_and_props.items(): + self.get_success( + store.store_local_media( + media_id=name, + media_type=props.get("mimetype", "image/png"), + time_now_ms=self.clock.time_msec(), + upload_name=None, + media_length=props.get("size", 50), + user_id=UserID.from_string("@rin:test"), + ) + ) + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 901b264c0c88f39cbfb8b2229e0dc57968882658 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 31 Jan 2022 20:20:05 +0100 Subject: Add type hints to `tests/rest/admin` (#11851) --- changelog.d/11851.misc | 1 + mypy.ini | 3 - tests/rest/admin/test_admin.py | 134 +++++--------- tests/rest/admin/test_user.py | 262 ++++++++++++++-------------- tests/rest/admin/test_username_available.py | 16 +- 5 files changed, 184 insertions(+), 232 deletions(-) create mode 100644 changelog.d/11851.misc (limited to 'tests/rest') diff --git a/changelog.d/11851.misc b/changelog.d/11851.misc new file mode 100644 index 0000000000..ccc3ec3482 --- /dev/null +++ b/changelog.d/11851.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/admin`. diff --git a/mypy.ini b/mypy.ini index 85fa22d28f..2884078d0a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -77,9 +77,6 @@ exclude = (?x) |tests/push/test_http.py |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py - |tests/rest/admin/test_admin.py - |tests/rest/admin/test_user.py - |tests/rest/admin/test_username_available.py |tests/rest/client/test_account.py |tests/rest/client/test_events.py |tests/rest/client/test_filter.py diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 3adadcb46b..849d00ab4d 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import urllib.parse from http import HTTPStatus -from unittest.mock import Mock +from typing import List -from twisted.internet.defer import Deferred +from parameterized import parameterized + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.http.server import JsonResource -from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet from synapse.rest.client import groups, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -33,12 +35,12 @@ from tests.test_utils import SMALL_PNG class VersionTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/server_version" - def create_test_resource(self): + def create_test_resource(self) -> JsonResource: resource = JsonResource(self.hs) VersionServlet(self.hs).register(resource) return resource - def test_version_string(self): + def test_version_string(self) -> None: channel = self.make_request("GET", self.url, shorthand=False) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) @@ -54,14 +56,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): groups.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") - def test_delete_group(self): + def test_delete_group(self) -> None: # Create a new group channel = self.make_request( "POST", @@ -112,7 +114,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - def _check_group(self, group_id, expect_code): + def _check_group(self, group_id: str, expect_code: int) -> None: """Assert that trying to fetch the given group results in the given HTTP status code """ @@ -124,7 +126,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): self.assertEqual(expect_code, channel.code, msg=channel.json_body) - def _get_groups_user_is_in(self, access_token): + def _get_groups_user_is_in(self, access_token: str) -> List[str]: """Returns the list of groups the user is in (given their access token)""" channel = self.make_request("GET", b"/joined_groups", access_token=access_token) @@ -143,59 +145,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] self.upload_resource = self.media_repo.children[b"upload"] - def make_homeserver(self, reactor, clock): - - self.fetches = [] - - async def get_file(destination, path, output_stream, args=None, max_size=None): - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ - - def write_to(r): - data, response = r - output_stream.write(data) - return response - - d = Deferred() - d.addCallback(write_to) - self.fetches.append((d, destination, path, args)) - return await make_deferred_yieldable(d) - - client = Mock() - client.get_file = get_file - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - - config = self.default_config() - config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} - config["max_image_pixels"] = 2000000 - - provider_config = { - "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - config["media_storage_providers"] = [provider_config] - - hs = self.setup_test_homeserver(config=config, federation_http_client=client) - - return hs - - def _ensure_quarantined(self, admin_user_tok, server_and_media_id): + def _ensure_quarantined( + self, admin_user_tok: str, server_and_media_id: str + ) -> None: """Ensure a piece of media is quarantined when trying to access it.""" channel = make_request( self.reactor, @@ -216,12 +174,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ), ) - def test_quarantine_media_requires_admin(self): + @parameterized.expand( + [ + # Attempt quarantine media APIs as non-admin + "/_synapse/admin/v1/media/quarantine/example.org/abcde12345", + # And the roomID/userID endpoint + "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine", + ] + ) + def test_quarantine_media_requires_admin(self, url: str) -> None: self.register_user("nonadmin", "pass", admin=False) non_admin_user_tok = self.login("nonadmin", "pass") - # Attempt quarantine media APIs as non-admin - url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" channel = self.make_request( "POST", url.encode("ascii"), @@ -235,22 +199,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): msg="Expected forbidden on quarantining media as a non-admin", ) - # And the roomID/userID endpoint - url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" - channel = self.make_request( - "POST", - url.encode("ascii"), - access_token=non_admin_user_tok, - ) - - # Expect a forbidden error - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg="Expected forbidden on quarantining media as a non-admin", - ) - - def test_quarantine_media_by_id(self): + def test_quarantine_media_by_id(self) -> None: self.register_user("id_admin", "pass", admin=True) admin_user_tok = self.login("id_admin", "pass") @@ -295,7 +244,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) - def test_quarantine_all_media_in_room(self, override_url_template=None): + @parameterized.expand( + [ + # regular API path + "/_synapse/admin/v1/room/%s/media/quarantine", + # deprecated API path + "/_synapse/admin/v1/quarantine_media/%s", + ] + ) + def test_quarantine_all_media_in_room(self, url: str) -> None: self.register_user("room_admin", "pass", admin=True) admin_user_tok = self.login("room_admin", "pass") @@ -333,16 +290,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): tok=non_admin_user_tok, ) - # Quarantine all media in the room - if override_url_template: - url = override_url_template % urllib.parse.quote(room_id) - else: - url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( - room_id - ) channel = self.make_request( "POST", - url, + url % urllib.parse.quote(room_id), access_token=admin_user_tok, ) self.pump(1.0) @@ -359,11 +309,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_2) - def test_quarantine_all_media_in_room_deprecated_api_path(self): - # Perform the above test with the deprecated API path - self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s") - - def test_quarantine_all_media_by_user(self): + def test_quarantine_all_media_by_user(self) -> None: self.register_user("user_admin", "pass", admin=True) admin_user_tok = self.login("user_admin", "pass") @@ -401,7 +347,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_2) - def test_cannot_quarantine_safe_media(self): + def test_cannot_quarantine_safe_media(self) -> None: self.register_user("user_admin", "pass", admin=True) admin_user_tok = self.login("user_admin", "pass") @@ -475,7 +421,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -488,7 +434,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" self.url_status = "/_synapse/admin/v1/purge_history_status/" - def test_purge_history(self): + def test_purge_history(self) -> None: """ Simple test of purge history API. Test only that is is possible to call, get status HTTPStatus.OK and purge_id. diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9711405735..272637e965 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -23,13 +23,17 @@ from unittest.mock import Mock, patch from parameterized import parameterized, parameterized_class +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client import devices, login, logout, profile, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.server import HomeServer from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -44,7 +48,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = "/_synapse/admin/v1/register" @@ -61,12 +65,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.config.registration.registration_shared_secret = "shared" - self.hs.get_media_repository = Mock() - self.hs.get_deactivate_account_handler = Mock() + self.hs.get_media_repository = Mock() # type: ignore[assignment] + self.hs.get_deactivate_account_handler = Mock() # type: ignore[assignment] return self.hs - def test_disabled(self): + def test_disabled(self) -> None: """ If there is no shared secret, registration through this method will be prevented. @@ -80,7 +84,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "Shared secret registration is not enabled", channel.json_body["error"] ) - def test_get_nonce(self): + def test_get_nonce(self) -> None: """ Calling GET on the endpoint will return a randomised nonce, using the homeserver's secrets provider. @@ -93,7 +97,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body, {"nonce": "abcd"}) - def test_expired_nonce(self): + def test_expired_nonce(self) -> None: """ Calling GET on the endpoint will return a randomised nonce, which will only last for SALT_TIMEOUT (60s). @@ -118,7 +122,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) - def test_register_incorrect_nonce(self): + def test_register_incorrect_nonce(self) -> None: """ Only the provided nonce can be used, as it's checked in the MAC. """ @@ -141,7 +145,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) - def test_register_correct_nonce(self): + def test_register_correct_nonce(self) -> None: """ When the correct nonce is provided, and the right key is provided, the user is registered. @@ -168,7 +172,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) - def test_nonce_reuse(self): + def test_nonce_reuse(self) -> None: """ A valid unrecognised nonce. """ @@ -197,14 +201,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) - def test_missing_parts(self): + def test_missing_parts(self) -> None: """ Synapse will complain if you don't give nonce, username, password, and mac. Admin and user_types are optional. Additional checks are done for length and type. """ - def nonce(): + def nonce() -> str: channel = self.make_request("GET", self.url) return channel.json_body["nonce"] @@ -297,7 +301,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) - def test_displayname(self): + def test_displayname(self) -> None: """ Test that displayname of new user is set """ @@ -400,7 +404,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} ) - def test_register_mau_limit_reached(self): + def test_register_mau_limit_reached(self) -> None: """ Check we can register a user via the shared secret registration API even if the MAU limit is reached. @@ -450,13 +454,13 @@ class UsersListTestCase(unittest.HomeserverTestCase): ] url = "/_synapse/admin/v2/users" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list users without authentication. """ @@ -465,7 +469,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -477,7 +481,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_all_users(self): + def test_all_users(self) -> None: """ List all users, including deactivated users. """ @@ -497,7 +501,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that all fields are available self._check_fields(channel.json_body["users"]) - def test_search_term(self): + def test_search_term(self) -> None: """Test that searching for a users works correctly""" def _search_test( @@ -505,7 +509,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): search_term: str, search_field: Optional[str] = "name", expected_http_code: Optional[int] = HTTPStatus.OK, - ): + ) -> None: """Search for a user and check that the returned user's id is a match Args: @@ -575,7 +579,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -640,7 +644,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of users with limit """ @@ -661,7 +665,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["next_token"], "5") self._check_fields(channel.json_body["users"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of users with a defined starting point (from) """ @@ -682,7 +686,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["users"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of users with a defined starting point and limit """ @@ -703,7 +707,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 10) self._check_fields(channel.json_body["users"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -765,7 +769,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_order_by(self): + def test_order_by(self) -> None: """ Testing order list with parameter `order_by` """ @@ -843,7 +847,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_user_list: List[str], order_by: Optional[str], dir: Optional[str] = None, - ): + ) -> None: """Request the list of users in a certain order. Assert that order is what we expect Args: @@ -870,7 +874,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(expected_user_list, returned_order) self._check_fields(channel.json_body["users"]) - def _check_fields(self, content: List[JsonDict]): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -886,7 +890,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertIn("avatar_url", u) self.assertIn("creation_ts", u) - def _create_users(self, number_users: int): + def _create_users(self, number_users: int) -> None: """ Create a number of users Args: @@ -908,7 +912,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -931,7 +935,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to deactivate users without authentication. """ @@ -940,7 +944,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + def test_requester_is_not_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -961,7 +965,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -975,7 +979,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_erase_is_not_bool(self): + def test_erase_is_not_bool(self) -> None: """ If parameter `erase` is not boolean, return an error """ @@ -990,7 +994,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -1001,7 +1005,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) - def test_deactivate_user_erase_true(self): + def test_deactivate_user_erase_true(self) -> None: """ Test deactivating a user and set `erase` to `true` """ @@ -1046,7 +1050,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", True) - def test_deactivate_user_erase_false(self): + def test_deactivate_user_erase_false(self) -> None: """ Test deactivating a user and set `erase` to `false` """ @@ -1091,7 +1095,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", False) - def test_deactivate_user_erase_true_no_profile(self): + def test_deactivate_user_erase_true_no_profile(self) -> None: """ Test deactivating a user and set `erase` to `true` if user has no profile information (stored in the database table `profiles`). @@ -1162,7 +1166,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.auth_handler = hs.get_auth_handler() @@ -1185,7 +1189,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.url_prefix = "/_synapse/admin/v2/users/%s" self.url_other_user = self.url_prefix % self.other_user - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -1210,7 +1214,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -1224,7 +1228,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -1319,7 +1323,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_get_user(self): + def test_get_user(self) -> None: """ Test a simple get of a user. """ @@ -1334,7 +1338,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) - def test_create_server_admin(self): + def test_create_server_admin(self) -> None: """ Check that a new admin user is created successfully. """ @@ -1383,7 +1387,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) - def test_create_user(self): + def test_create_user(self) -> None: """ Check that a new regular user is created successfully. """ @@ -1450,7 +1454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} ) - def test_create_user_mau_limit_reached_active_admin(self): + def test_create_user_mau_limit_reached_active_admin(self) -> None: """ Check that an admin can register a new user via the admin API even if the MAU limit is reached. @@ -1496,7 +1500,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} ) - def test_create_user_mau_limit_reached_passive_admin(self): + def test_create_user_mau_limit_reached_passive_admin(self) -> None: """ Check that an admin can register a new user via the admin API even if the MAU limit is reached. @@ -1541,7 +1545,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "public_baseurl": "https://example.com", } ) - def test_create_user_email_notif_for_new_users(self): + def test_create_user_email_notif_for_new_users(self) -> None: """ Check that a new regular user is created successfully and got an email pusher. @@ -1584,7 +1588,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "public_baseurl": "https://example.com", } ) - def test_create_user_email_no_notif_for_new_users(self): + def test_create_user_email_no_notif_for_new_users(self) -> None: """ Check that a new regular user is created successfully and got not an email pusher. @@ -1615,7 +1619,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 0) - def test_set_password(self): + def test_set_password(self) -> None: """ Test setting a new password for another user. """ @@ -1631,7 +1635,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) - def test_set_displayname(self): + def test_set_displayname(self) -> None: """ Test setting the displayname of another user. """ @@ -1659,7 +1663,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) - def test_set_threepid(self): + def test_set_threepid(self) -> None: """ Test setting threepid for an other user. """ @@ -1740,7 +1744,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) - def test_set_duplicate_threepid(self): + def test_set_duplicate_threepid(self) -> None: """ Test setting the same threepid for a second user. First user loses and second user gets mapping of this threepid. @@ -1827,7 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) - def test_set_external_id(self): + def test_set_external_id(self) -> None: """ Test setting external id for an other user. """ @@ -1925,7 +1929,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) - def test_set_duplicate_external_id(self): + def test_set_duplicate_external_id(self) -> None: """ Test that setting the same external id for a second user fails and external id from user must not be changed. @@ -2048,7 +2052,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self._check_fields(channel.json_body) - def test_deactivate_user(self): + def test_deactivate_user(self) -> None: """ Test deactivating another user. """ @@ -2113,7 +2117,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("password_hash", channel.json_body) @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) - def test_change_name_deactivate_user_user_directory(self): + def test_change_name_deactivate_user_user_directory(self) -> None: """ Test change profile information of a deactivated user and check that it does not appear in user directory @@ -2156,7 +2160,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): profile = self.get_success(self.store.get_user_in_directory(self.other_user)) self.assertIsNone(profile) - def test_reactivate_user(self): + def test_reactivate_user(self) -> None: """ Test reactivating another user. """ @@ -2189,7 +2193,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("password_hash", channel.json_body) @override_config({"password_config": {"localdb_enabled": False}}) - def test_reactivate_user_localdb_disabled(self): + def test_reactivate_user_localdb_disabled(self) -> None: """ Test reactivating another user when using SSO. """ @@ -2223,7 +2227,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("password_hash", channel.json_body) @override_config({"password_config": {"enabled": False}}) - def test_reactivate_user_password_disabled(self): + def test_reactivate_user_password_disabled(self) -> None: """ Test reactivating another user when using SSO. """ @@ -2256,7 +2260,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # This key was removed intentionally. Ensure it is not accidentally re-included. self.assertNotIn("password_hash", channel.json_body) - def test_set_user_as_admin(self): + def test_set_user_as_admin(self) -> None: """ Test setting the admin flag on a user. """ @@ -2284,7 +2288,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) - def test_set_user_type(self): + def test_set_user_type(self) -> None: """ Test changing user type. """ @@ -2335,7 +2339,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) - def test_accidental_deactivation_prevention(self): + def test_accidental_deactivation_prevention(self) -> None: """ Ensure an account can't accidentally be deactivated by using a str value for the deactivated body parameter @@ -2418,7 +2422,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # This key was removed intentionally. Ensure it is not accidentally re-included. self.assertNotIn("password_hash", channel.json_body) - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: JsonDict) -> None: """Checks that the expected user attributes are present in content Args: @@ -2448,7 +2452,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2457,7 +2461,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list rooms of an user without authentication. """ @@ -2466,7 +2470,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -2481,7 +2485,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns an empty list """ @@ -2496,7 +2500,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list """ @@ -2512,7 +2516,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) - def test_no_memberships(self): + def test_no_memberships(self) -> None: """ Tests that a normal lookup for rooms is successfully if user has no memberships @@ -2528,7 +2532,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) - def test_get_rooms(self): + def test_get_rooms(self) -> None: """ Tests that a normal lookup for rooms is successfully """ @@ -2549,7 +2553,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) - def test_get_rooms_with_nonlocal_user(self): + def test_get_rooms_with_nonlocal_user(self) -> None: """ Tests that a normal lookup for rooms is successful with a non-local user """ @@ -2604,7 +2608,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2615,7 +2619,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list pushers of an user without authentication. """ @@ -2624,7 +2628,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -2639,7 +2643,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -2653,7 +2657,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -2668,7 +2672,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) - def test_get_pushers(self): + def test_get_pushers(self) -> None: """ Tests that a normal lookup for pushers is successfully """ @@ -2732,7 +2736,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) @@ -2746,7 +2750,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["GET", "DELETE"]) - def test_no_auth(self, method: str): + def test_no_auth(self, method: str) -> None: """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) @@ -2754,7 +2758,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """If the user is not a server admin, an error is returned.""" other_user_token = self.login("user", "pass") @@ -2768,7 +2772,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) - def test_user_does_not_exist(self, method: str): + def test_user_does_not_exist(self, method: str) -> None: """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( @@ -2781,7 +2785,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) - def test_user_is_not_local(self, method: str): + def test_user_is_not_local(self, method: str) -> None: """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" @@ -2794,7 +2798,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) - def test_limit_GET(self): + def test_limit_GET(self) -> None: """Testing list of media with limit""" number_media = 20 @@ -2813,7 +2817,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["media"]) - def test_limit_DELETE(self): + def test_limit_DELETE(self) -> None: """Testing delete of media with limit""" number_media = 20 @@ -2830,7 +2834,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) - def test_from_GET(self): + def test_from_GET(self) -> None: """Testing list of media with a defined starting point (from)""" number_media = 20 @@ -2849,7 +2853,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def test_from_DELETE(self): + def test_from_DELETE(self) -> None: """Testing delete of media with a defined starting point (from)""" number_media = 20 @@ -2866,7 +2870,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) - def test_limit_and_from_GET(self): + def test_limit_and_from_GET(self) -> None: """Testing list of media with a defined starting point and limit""" number_media = 20 @@ -2885,7 +2889,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_limit_and_from_DELETE(self): + def test_limit_and_from_DELETE(self) -> None: """Testing delete of media with a defined starting point and limit""" number_media = 20 @@ -2903,7 +2907,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["deleted_media"]), 10) @parameterized.expand(["GET", "DELETE"]) - def test_invalid_parameter(self, method: str): + def test_invalid_parameter(self, method: str) -> None: """If parameters are invalid, an error is returned.""" # unkown order_by channel = self.make_request( @@ -2945,7 +2949,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place @@ -3010,7 +3014,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_user_has_no_media_GET(self): + def test_user_has_no_media_GET(self) -> None: """ Tests that a normal lookup for media is successfully if user has no media created @@ -3026,7 +3030,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) - def test_user_has_no_media_DELETE(self): + def test_user_has_no_media_DELETE(self) -> None: """ Tests that a delete is successful if user has no media """ @@ -3041,7 +3045,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) - def test_get_media(self): + def test_get_media(self) -> None: """Tests that a normal lookup for media is successful""" number_media = 5 @@ -3060,7 +3064,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def test_delete_media(self): + def test_delete_media(self) -> None: """Tests that a normal delete of media is successful""" number_media = 5 @@ -3089,7 +3093,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): for local_path in local_paths: self.assertFalse(os.path.exists(local_path)) - def test_order_by(self): + def test_order_by(self) -> None: """ Testing order list with parameter `order_by` """ @@ -3252,7 +3256,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): return media_id - def _check_fields(self, content: List[JsonDict]): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -3272,7 +3276,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): expected_media_list: List[str], order_by: Optional[str], dir: Optional[str] = None, - ): + ) -> None: """Request the list of media in a certain order. Assert that order is what we expect Args: @@ -3312,7 +3316,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): logout.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -3331,14 +3335,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) return channel.json_body["access_token"] - def test_no_auth(self): + def test_no_auth(self) -> None: """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_not_admin(self): + def test_not_admin(self) -> None: """Try to login as a user as a non-admin user.""" channel = self.make_request( "POST", self.url, b"{}", access_token=self.other_user_tok @@ -3346,7 +3350,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) - def test_send_event(self): + def test_send_event(self) -> None: """Test that sending event as a user works.""" # Create a room. room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) @@ -3360,7 +3364,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): event = self.get_success(self.store.get_event(event_id)) self.assertEqual(event.sender, self.other_user) - def test_devices(self): + def test_devices(self) -> None: """Tests that logging in as a user doesn't create a new device for them.""" # Login in as the user self._get_token() @@ -3374,7 +3378,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) - def test_logout(self): + def test_logout(self) -> None: """Test that calling `/logout` with the token works.""" # Login in as the user puppet_token = self._get_token() @@ -3397,7 +3401,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_user_logout_all(self): + def test_user_logout_all(self) -> None: """Tests that the target user calling `/logout/all` does *not* expire the token. """ @@ -3424,7 +3428,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) - def test_admin_logout_all(self): + def test_admin_logout_all(self) -> None: """Tests that the admin user calling `/logout/all` does expire the token. """ @@ -3464,7 +3468,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "form_secret": "123secret", } ) - def test_consent(self): + def test_consent(self) -> None: """Test that sending a message is not subject to the privacy policies.""" # Have the admin user accept the terms. self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0")) @@ -3492,7 +3496,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): @override_config( {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0} ) - def test_mau_limit(self): + def test_mau_limit(self) -> None: # Create a room as the admin user. This will bump the monthly active users to 1. room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -3524,14 +3528,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") - self.url = self.url_prefix % self.other_user + self.url = self.url_prefix % self.other_user # type: ignore[attr-defined] - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get information of an user without authentication. """ @@ -3539,7 +3543,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + def test_requester_is_not_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -3554,11 +3558,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ - url = self.url_prefix % "@unknown_person:unknown_domain" + url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined] channel = self.make_request( "GET", @@ -3568,7 +3572,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) - def test_get_whois_admin(self): + def test_get_whois_admin(self) -> None: """ The lookup should succeed for an admin. """ @@ -3581,7 +3585,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - def test_get_whois_user(self): + def test_get_whois_user(self) -> None: """ The lookup should succeed for a normal user looking up their own information. """ @@ -3604,7 +3608,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -3617,7 +3621,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["POST", "DELETE"]) - def test_no_auth(self, method: str): + def test_no_auth(self, method: str) -> None: """ Try to get information of an user without authentication. """ @@ -3626,7 +3630,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) - def test_requester_is_not_admin(self, method: str): + def test_requester_is_not_admin(self, method: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -3637,7 +3641,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) - def test_user_is_not_local(self, method: str): + def test_user_is_not_local(self, method: str) -> None: """ Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -3646,7 +3650,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - def test_success(self): + def test_success(self) -> None: """ Shadow-banning should succeed for an admin. """ @@ -3682,7 +3686,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -3695,7 +3699,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["GET", "POST", "DELETE"]) - def test_no_auth(self, method: str): + def test_no_auth(self, method: str) -> None: """ Try to get information of a user without authentication. """ @@ -3705,7 +3709,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -3721,7 +3725,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) - def test_user_does_not_exist(self, method: str): + def test_user_does_not_exist(self, method: str) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -3743,7 +3747,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ("DELETE", "Only local users can be ratelimited"), ] ) - def test_user_is_not_local(self, method: str, error_msg: str): + def test_user_is_not_local(self, method: str, error_msg: str) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -3760,7 +3764,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -3808,7 +3812,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_return_zero_when_null(self): + def test_return_zero_when_null(self) -> None: """ If values in database are `null` API should return an int `0` """ @@ -3834,7 +3838,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) - def test_success(self): + def test_success(self) -> None: """ Rate-limiting (set/update/delete) should succeed for an admin. """ @@ -3908,7 +3912,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs) -> None: + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 7978626e71..b21f6d4689 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -14,9 +14,13 @@ from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes, SynapseError from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,11 +32,11 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): ] url = "/_synapse/admin/v1/username_available" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - async def check_username(username): + async def check_username(username: str) -> bool: if username == "allowed": return True raise SynapseError( @@ -44,24 +48,24 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): handler = self.hs.get_registration_handler() handler.check_username = check_username - def test_username_available(self): + def test_username_available(self) -> None: """ The endpoint should return a HTTPStatus.OK response if the username does not exist """ url = "%s?username=%s" % (self.url, "allowed") - channel = self.make_request("GET", url, None, self.admin_user_tok) + channel = self.make_request("GET", url, access_token=self.admin_user_tok) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["available"]) - def test_username_unavailable(self): + def test_username_unavailable(self) -> None: """ The endpoint should return a HTTPStatus.OK response if the username does not exist """ url = "%s?username=%s" % (self.url, "disallowed") - channel = self.make_request("GET", url, None, self.admin_user_tok) + channel = self.make_request("GET", url, access_token=self.admin_user_tok) self.assertEqual( HTTPStatus.BAD_REQUEST, -- cgit 1.5.1