From f6767abc054f3461cd9a70ba096fcf9a8e640edb Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 8 Jul 2021 10:57:13 -0500 Subject: Remove functionality associated with unused historical stats tables (#9721) Fixes #9602 --- tests/rest/admin/test_room.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests/rest') diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ee071c2477..959d3cea77 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1753,7 +1753,6 @@ PURGE_TABLES = [ "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", -- cgit 1.5.1 From 89cfc3dd9849b0580146151098ad039a7680c63f Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 12:43:15 +0200 Subject: [pyupgrade] `tests/` (#10347) --- changelog.d/10347.misc | 1 + tests/config/test_load.py | 4 ++-- tests/handlers/test_profile.py | 2 +- .../http/federation/test_matrix_federation_agent.py | 2 +- tests/http/test_fedclient.py | 8 +++----- tests/replication/_base.py | 6 +++--- tests/replication/test_multi_media_repo.py | 4 ++-- tests/replication/test_sharded_event_persister.py | 6 +++--- tests/rest/admin/test_admin.py | 6 ++---- tests/rest/admin/test_room.py | 20 ++++++++++---------- tests/rest/client/v1/test_rooms.py | 14 +++++++------- tests/rest/client/v2_alpha/test_relations.py | 2 +- tests/rest/client/v2_alpha/test_report_event.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_profile.py | 12 ++---------- tests/storage/test_purge.py | 2 +- tests/storage/test_room.py | 2 +- tests/test_types.py | 4 +--- tests/unittest.py | 2 +- 20 files changed, 45 insertions(+), 58 deletions(-) create mode 100644 changelog.d/10347.misc (limited to 'tests/rest') diff --git a/changelog.d/10347.misc b/changelog.d/10347.misc new file mode 100644 index 0000000000..b2275a1350 --- /dev/null +++ b/changelog.d/10347.misc @@ -0,0 +1 @@ +Run `pyupgrade` on the codebase. \ No newline at end of file diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ebe2c05165..903c69127d 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) @@ -120,7 +120,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def generate_config_and_remove_lines_containing(self, needle): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: contents = f.readlines() contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index cdb41101b3..2928c4f48c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -103,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - (self.get_success(self.store.get_profile_displayname(self.frank.localpart))) + self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) def test_set_my_name_if_disabled(self): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index e45980316b..a37bce08c3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -273,7 +273,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode("ascii")) + request.write(b'{ "a": 1 }') request.finish() self.reactor.pump((0.1,)) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ed9a884d76..d9a8b077d3 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -102,7 +102,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode("ascii") + res_json = b'{ "a": 1 }' protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -339,10 +339,8 @@ class FederationClientTests(HomeserverTestCase): # Send it the HTTP response client.dataReceived( - ( - b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n" - ) + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" ) # Push by enough to time it out diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 624bd1b927..386ea70a25 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -550,12 +550,12 @@ class FakeRedisPubSubProtocol(Protocol): if obj is None: return "$-1\r\n" if isinstance(obj, str): - return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + return f"${len(obj)}\r\n{obj}\r\n" if isinstance(obj, int): - return ":{val}\r\n".format(val=obj) + return f":{obj}\r\n" if isinstance(obj, (list, tuple)): items = "".join(self.encode(a) for a in obj) - return "*{len}\r\n{items}".format(len=len(obj), items=items) + return f"*{len(obj)}\r\n{items}" raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 76e6644353..b42f1288eb 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -70,7 +70,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, FakeSite(resource), "GET", - "/{}/{}".format(target, media_id), + f"/{target}/{media_id}", shorthand=False, access_token=self.access_token, await_result=False, @@ -113,7 +113,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5eca5c165d..f3615af97e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -211,7 +211,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) @@ -241,7 +241,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(vector_clock_token), + f"/sync?since={vector_clock_token}", access_token=access_token, ) @@ -266,7 +266,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 2f7090e554..a7c6e595b9 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Create a new group channel = self.make_request( "POST", - "/create_group".encode("ascii"), + b"/create_group", access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request( - "GET", "/joined_groups".encode("ascii"), access_token=access_token - ) + channel = self.make_request("GET", b"/joined_groups", access_token=access_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 959d3cea77..17ec8bfd3b 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") def _assert_peek(self, room_id, expect_code): """Assert that the admin user can (or cannot) peek into the room.""" @@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") class RoomTestCase(unittest.HomeserverTestCase): @@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.public_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" def test_requester_is_no_admin(self): """ @@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Join user to room. - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e94566ffd7..3df070c936 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/join", content={"reason": reason}, access_token=self.second_tok, ) @@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) @@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/kick", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) @@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/ban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/unban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/invite", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 856aa8682f..2e2f94742e 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index 1ec6b05e5b..a76a6fef1e 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) resp = self.helper.send(self.room_id, tok=self.admin_user_tok) self.event_id = resp["event_id"] - self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" def test_reason_str_and_score_int(self): data = {"reason": "this makes me sad", "score": -100} diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 95e7075841..2d6b49692e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote("\u2603".encode("utf8")).encode("ascii") + filename = parse.quote("\u2603".encode()).encode("ascii") channel = self._req( b"inline; filename*=utf-8''" + filename + self.test_image.extension ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 41bef62ca8..43628ce44f 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -59,5 +59,5 @@ class DirectoryStoreTestCase(HomeserverTestCase): self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (self.get_success(self.store.get_association_from_room_alias(self.alias))) + self.get_success(self.store.get_association_from_room_alias(self.alias)) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 8a446da848..a1ba99ff14 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -45,11 +45,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_displayname(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) def test_avatar_url(self): @@ -76,9 +72,5 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_avatar_url(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) ) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 54c5b470c7..e5574063f1 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -75,7 +75,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format(token.topological + 1, token.stream + 1) + event = f"t{token.topological + 1}-{token.stream + 1}" # Purge everything before this topological token f = self.get_failure( diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 70257bf210..31ce7f6252 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -49,7 +49,7 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room_unknown_room(self): - self.assertIsNone((self.get_success(self.store.get_room("!uknown:test")))) + self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) def test_get_room_with_stats(self): self.assertDictContainsSubset( diff --git a/tests/test_types.py b/tests/test_types.py index d7881021d3..0d0c00d97a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -103,6 +103,4 @@ class MapUsernameTestCase(unittest.TestCase): def testNonAscii(self): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("tĂȘst"), "t=c3=aast") - self.assertEqual( - map_username_to_mxid_localpart("tĂȘst".encode("utf-8")), "t=c3=aast" - ) + self.assertEqual(map_username_to_mxid_localpart("tĂȘst".encode()), "t=c3=aast") diff --git a/tests/unittest.py b/tests/unittest.py index 74db7c08f1..907b94b10a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -140,7 +140,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))("Assert error for '.{}':".format(key)) from e + raise (type(e))(f"Assert error for '.{key}':") from e def assert_dict(self, required, actual): """Does a partial assert of a dict. -- cgit 1.5.1 From 93729719b8451493e1df9930feb9f02f14ea5cef Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 12:52:58 +0200 Subject: Use inline type hints in `tests/` (#10350) This PR is tantamount to running: python3.8 -m com2ann -v 6 tests/ (com2ann requires python 3.8 to run) --- changelog.d/10350.misc | 1 + tests/events/test_presence_router.py | 6 +++--- tests/module_api/test_api.py | 16 ++++++++-------- tests/replication/_base.py | 12 ++++++------ tests/replication/tcp/streams/test_events.py | 14 +++++++------- tests/replication/tcp/streams/test_receipts.py | 4 ++-- tests/replication/tcp/streams/test_typing.py | 4 ++-- tests/replication/test_multi_media_repo.py | 2 +- tests/rest/client/test_third_party_rules.py | 4 ++-- tests/rest/client/v1/test_login.py | 14 ++++++-------- tests/server.py | 8 +++++--- tests/storage/test_background_update.py | 4 +--- tests/storage/test_id_generators.py | 6 +++--- tests/test_state.py | 2 +- tests/test_utils/html_parsers.py | 6 +++--- tests/unittest.py | 2 +- tests/util/caches/test_descriptors.py | 2 +- tests/util/test_itertools.py | 18 +++++++++--------- 18 files changed, 62 insertions(+), 63 deletions(-) create mode 100644 changelog.d/10350.misc (limited to 'tests/rest') diff --git a/changelog.d/10350.misc b/changelog.d/10350.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10350.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. \ No newline at end of file diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 875b0d0a11..c4ad33194d 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -152,7 +152,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_one_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "boop") @@ -274,7 +274,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence_updates, _ = sync_presence(self, self.other_user_id) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "I'm online!") @@ -320,7 +320,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 2c68b9a13c..81d9e2f484 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -100,9 +100,9 @@ class ModuleApiTestCase(HomeserverTestCase): "content": content, "sender": user_id, } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.message") self.assertEqual(event.room_id, room_id) @@ -136,9 +136,9 @@ class ModuleApiTestCase(HomeserverTestCase): "sender": user_id, "state_key": "", } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.power_levels") self.assertEqual(event.room_id, room_id) @@ -281,7 +281,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] @@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 386ea70a25..e9fd991718 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() - self.server = server_factory.buildProtocol( + self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol( None - ) # type: ServerReplicationStreamProtocol + ) # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" @@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): fetching updates for given stream. """ - path = request.path # type: bytes # type: ignore + path: bytes = request.path # type: ignore self.assertRegex( path, br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" @@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. """ - servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + servlets: List[Callable[[HomeServer, JsonResource], None]] = [] def setUp(self): super().setUp() @@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler): super().__init__(hs) # list of received (stream_name, token, row) tuples - self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] + self.received_rdata_rows: List[Tuple[str, int, Any]] = [] async def on_rdata(self, stream_name, instance_name, token, rows): await super().on_rdata(stream_name, instance_name, token, rows) @@ -484,7 +484,7 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" - transport = None # type: Optional[FakeTransport] + transport: Optional[FakeTransport] = None def __init__(self, server: FakeRedisPubSubServer): self._server = server diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f51fa0a79e..666008425a 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) events = [ self._inject_state_event(sender=OTHER_USER) @@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual(row.data.event_id, pl_event.event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for stream_name, _, row in received_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) - events = [] # type: List[EventBase] + events: List[EventBase] = [] for user in user_ids: events.extend( self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) @@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual(row.data.event_id, pl_events[i].event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for _ in range(STATES_PER_USER + 1): stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 7f5d932f0b..38e292c1ab 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) @@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room2:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index ecd360c2d0..3ff5afc6e5 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) @@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index b42f1288eb..ffa425328f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request logger = logging.getLogger(__name__) -test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory] +test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e1fe72fc5d..c5e1c5458b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -233,11 +233,11 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "content": content, "sender": self.user_id, } - event = self.get_success( + event: EventBase = self.get_success( current_rules_module().module_api.create_and_send_event_into_room( event_dict ) - ) # type: EventBase + ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 605b952316..7eba69642a 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # stick the flows results in a dict by type - flow_results = {} # type: Dict[str, Any] + flow_results: Dict[str, Any] = {} for f in channel.json_body["flows"]: flow_type = f["type"] self.assertNotIn( @@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.close() # there should be a link for each href - returned_idps = [] # type: List[str] + returned_idps: List[str] = [] for link in p.links: path, query = link.split("?", 1) self.assertEqual(path, "pick_idp") @@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") assert cookie_headers - cookies = {} # type: Dict[str, str] + cookies: Dict[str, str] = {} for h in cookie_headers: key, value = h.split(";")[0].split("=", maxsplit=1) cookies[key] = value @@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode( - payload, secret, self.jwt_algorithm - ) # type: Union[str, bytes] + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) if isinstance(result, bytes): return result.decode("ascii") return result @@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") if isinstance(result, bytes): return result.decode("ascii") return result @@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie - cookies = {} # type: Dict[str,str] + cookies: Dict[str, str] = {} channel.extract_cookies(cookies) self.assertIn("username_mapping_session", cookies) session_id = cookies["username_mapping_session"] diff --git a/tests/server.py b/tests/server.py index f32d8dc375..6fddd3b305 100644 --- a/tests/server.py +++ b/tests/server.py @@ -52,7 +52,7 @@ class FakeChannel: _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) _ip = attr.ib(type=str, default="127.0.0.1") - _producer = None # type: Optional[Union[IPullProducer, IPushProducer]] + _producer: Optional[Union[IPullProducer, IPushProducer]] = None @property def json_body(self): @@ -316,8 +316,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] - lookups = self.lookups = {} # type: Dict[str, str] - self._thread_callbacks = deque() # type: Deque[Callable[[], None]] + self.lookups: Dict[str, str] = {} + self._thread_callbacks: Deque[Callable[[], None]] = deque() + + lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 069db0edc4..0da42b5ac5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -7,9 +7,7 @@ from tests import unittest class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = ( - self.hs.get_datastore().db_pool.updates - ) # type: BackgroundUpdater + self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 792b1c44c1..7486078284 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/test_state.py b/tests/test_state.py index 62f7095873..780eba823c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): self.store.register_events(graph.walk()) - context_store = {} # type: dict[str, EventContext] + context_store: dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index 1fbb38f4be..e878af5f12 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -23,13 +23,13 @@ class TestHtmlParser(HTMLParser): super().__init__() # a list of links found in the doc - self.links = [] # type: List[str] + self.links: List[str] = [] # the values of any hidden s: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] + self.hiddens: Dict[str, Optional[str]] = {} # the values of any radio buttons: map from name to list of values - self.radios = {} # type: Dict[str, List[Optional[str]]] + self.radios: Dict[str, List[Optional[str]]] = {} def handle_starttag( self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] diff --git a/tests/unittest.py b/tests/unittest.py index 907b94b10a..c6d9064423 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase): if not isinstance(deferred, Deferred): return d - results = [] # type: list + results: list = [] deferred.addBoth(results.append) self.pump(by=by) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 0277998cbe..39947a166b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -174,7 +174,7 @@ class DescriptorTestCase(unittest.TestCase): return self.result obj = Cls() - callbacks = set() # type: Set[str] + callbacks: Set[str] = set() # set off an asynchronous request obj.result = origin_d = defer.Deferred() diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index e712eb42ea..3c0ddd4f18 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase): ) def test_empty_input(self): - parts = chunk_seq([], 5) # type: Iterable[Sequence] + parts: Iterable[Sequence] = chunk_seq([], 5) self.assertEqual( list(parts), @@ -56,13 +56,13 @@ class SortTopologically(TestCase): def test_empty(self): "Test that an empty graph works correctly" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} self.assertEqual(list(sorted_topologically([], graph)), []) def test_handle_empty_graph(self): "Test that a graph where a node doesn't have an entry is treated as empty" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -70,7 +70,7 @@ class SortTopologically(TestCase): def test_disconnected(self): "Test that a graph with no edges work" - graph = {1: [], 2: []} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: []} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -78,19 +78,19 @@ class SortTopologically(TestCase): def test_linear(self): "Test that a simple `4 -> 3 -> 2 -> 1` graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_subset(self): "Test that only sorting a subset of the graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) def test_fork(self): "Test that a forked graph works" - graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should # always get the same one. @@ -98,12 +98,12 @@ class SortTopologically(TestCase): def test_duplicates(self): "Test that a graph with duplicate edges work" - graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_multiple_paths(self): "Test that a graph with multiple paths between two nodes work" - graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From 4e340412c020f685cb402a735b983f6e332e206b Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 19 Jul 2021 16:11:34 +0100 Subject: Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric (#10332) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/10332.feature | 1 + synapse/app/phone_stats_home.py | 4 + synapse/storage/databases/main/metrics.py | 129 ++++++++++++++++ tests/app/test_phone_stats_home.py | 242 ++++++++++++++++++++++++++++++ tests/rest/client/v1/utils.py | 30 +++- tests/unittest.py | 15 +- 6 files changed, 416 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10332.feature (limited to 'tests/rest') diff --git a/changelog.d/10332.feature b/changelog.d/10332.feature new file mode 100644 index 0000000000..091947ff22 --- /dev/null +++ b/changelog.d/10332.feature @@ -0,0 +1 @@ +Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 8f86cecb76..7904c246df 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -107,6 +107,10 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): for name, count in r30_results.items(): stats["r30_users_" + name] = count + r30v2_results = await hs.get_datastore().count_r30_users() + for name, count in r30v2_results.items(): + stats["r30v2_users_" + name] = count + stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index e3a544d9b2..dc0bbc56ac 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -316,6 +316,135 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + async def count_r30v2_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as users that: + - Appear more than once in the past 60 days + - Have more than 30 days between the most and least recent appearances that + occurred in the past 60 days. + + (This is the second version of this metric, hence R30'v2') + + Returns: + A mapping from client type to the number of 30-day retained users for that client. + + The dict keys are: + - "all" (a combined number of users across any and all clients) + - "android" (Element Android) + - "ios" (Element iOS) + - "electron" (Element Desktop) + - "web" (any web application -- it's not possible to distinguish Element Web here) + """ + + def _count_r30v2_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs + one_day_from_now_in_secs = now + 86400 + + # This is the 'per-platform' count. + sql = """ + SELECT + client_type, + count(client_type) + FROM + ( + SELECT + user_id, + CASE + WHEN + LOWER(user_agent) LIKE '%%riot%%' OR + LOWER(user_agent) LIKE '%%element%%' + THEN CASE + WHEN + LOWER(user_agent) LIKE '%%electron%%' + THEN 'electron' + WHEN + LOWER(user_agent) LIKE '%%android%%' + THEN 'android' + WHEN + LOWER(user_agent) LIKE '%%ios%%' + THEN 'ios' + ELSE 'unknown' + END + WHEN + LOWER(user_agent) LIKE '%%mozilla%%' OR + LOWER(user_agent) LIKE '%%gecko%%' + THEN 'web' + ELSE 'unknown' + END as client_type + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id, + client_type + HAVING + max(timestamp) - min(timestamp) > ? + ) AS temp + GROUP BY + client_type + ; + """ + + # We initialise all the client types to zero, so we get an explicit + # zero if they don't appear in the query results + results = {"ios": 0, "android": 0, "web": 0, "electron": 0} + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + + for row in txn: + if row[0] == "unknown": + continue + results[row[0]] = row[1] + + # This is the 'all users' count. + sql = """ + SELECT COUNT(*) FROM ( + SELECT + 1 + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id + HAVING + max(timestamp) - min(timestamp) > ? + ) AS r30_users + """ + + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + row = txn.fetchone() + if row is None: + results["all"] = 0 + else: + results["all"] = row[0] + + return results + + return await self.db_pool.runInteraction( + "count_r30v2_users", _count_r30v2_users + ) + def _get_start_of_day(self): """ Returns millisecond unixtime for start of UTC day. diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 2da6ba4dde..5527e278db 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,9 +1,11 @@ import synapse +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.rest.client.v1 import login, room from tests import unittest from tests.unittest import HomeserverTestCase +FIVE_MINUTES_IN_SECONDS = 300 ONE_DAY_IN_SECONDS = 86400 @@ -151,3 +153,243 @@ class PhoneHomeTestCase(HomeserverTestCase): # *Now* the user appears in R30. r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + +class PhoneHomeR30V2TestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def _advance_to(self, desired_time_secs: float): + now = self.hs.get_clock().time() + assert now < desired_time_secs + self.reactor.advance(desired_time_secs - now) + + def make_homeserver(self, reactor, clock): + hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) + + # We don't want our tests to actually report statistics, so check + # that it's not enabled + assert not hs.config.report_stats + + # This starts the needed data collection that we rely on to calculate + # R30v2 metrics. + start_phone_stats_home(hs) + return hs + + def test_r30v2_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30v2 metric + to consider a user 'retained'. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + first_post_at = self.hs.get_clock().time() + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the R30 results do not count that user. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance 31 days. + # (R30v2 includes users with **more** than 30 days between the two visits, + # and user_daily_visits records the timestamp as the start of the day.) + self.reactor.advance(31 * ONE_DAY_IN_SECONDS) + # Also advance 5 minutes to let another user_daily_visits update occur + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait a few minutes for the user_daily_visits table to + # be updated by a background process. + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user is counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance to JUST under 60 days after the user's first post + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5) + + # Check the user is still counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance into the next day. The user's first activity is now more than 60 days old. + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5) + + # Check the user is now no longer counted in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30v2 statistic, even if they post every day + during that time! + """ + + # set a custom user-agent to impersonate Element/Android. + headers = ( + ( + "User-Agent", + "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the user does not contribute to R30 yet. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS) + self.helper.send( + room_id, "I'm still here", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Notice that the user *still* does not contribute to R30! + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # advance yet another day with more activity + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send( + room_id, "Still here!", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user appears in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_returning_dormant_users_not_counted(self): + """ + Tests that dormant users (users inactive for a long time) do not + contribute to R30v2 when they return for just a single day. + This is a key difference between R30 and R30v2. + """ + + # set a custom user-agent to impersonate Element/iOS. + headers = ( + ( + "User-Agent", + "Riot/1.4 (iPhone; iOS 13; Scale/4.00)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # the user goes inactive for 2 months + self.reactor.advance(60 * ONE_DAY_IN_SECONDS) + + # the user returns for one day, perhaps just to check out a new feature + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check that the user does not contribute to R30v2, even though it's been + # more than 30 days since registration. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Check that this is a situation where old R30 differs: + # old R30 DOES count this as 'retained'. + r30_results = self.get_success(store.count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "ios": 1}) + + # Now we want to check that the user will still be able to appear in + # R30v2 as long as the user performs some other activity between + # 30 and 60 days later. + self.reactor.advance(32 * ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # (give time for tables to update) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Check the user now satisfies the requirements to appear in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 59.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS) + + # Check the user still appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 60.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # Check the user no longer appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 69798e95c3..fc2d35596e 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -19,7 +19,7 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Mapping, MutableMapping, Optional +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from unittest.mock import patch import attr @@ -53,6 +53,9 @@ class RestHelper: tok: str = None, expect_code: int = 200, extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> str: """ Create a room. @@ -87,6 +90,7 @@ class RestHelper: "POST", path, json.dumps(content).encode("utf8"), + custom_headers=custom_headers, ) assert channel.result["code"] == b"%d" % expect_code, channel.result @@ -175,14 +179,30 @@ class RestHelper: self.auth_user_id = temp_id - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): if body is None: body = "body_text_here" content = {"msgtype": "m.text", "body": body} return self.send_event( - room_id, "m.room.message", content, txn_id, tok, expect_code + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, ) def send_event( @@ -193,6 +213,9 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -207,6 +230,7 @@ class RestHelper: "PUT", path, json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, ) assert ( diff --git a/tests/unittest.py b/tests/unittest.py index c6d9064423..3eec9c4d5b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase): user_id = channel.json_body["user_id"] return user_id - def login(self, username, password, device_id=None): + def login( + self, + username, + password, + device_id=None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): """ Log in a user, and get an access token. Requires the Login API be registered. @@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase): body["device_id"] = device_id channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") + "POST", + "/_matrix/client/r0/login", + json.dumps(body).encode("utf8"), + custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1 From a743bf46949e851c9a10d8e01a138659f3af2484 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 20 Jul 2021 12:39:46 +0200 Subject: Port the ThirdPartyEventRules module interface to the new generic interface (#10386) Port the third-party event rules interface to the generic module interface introduced in v1.37.0 --- changelog.d/10386.removal | 1 + docs/modules.md | 62 ++++++- docs/sample_config.yaml | 13 -- docs/upgrade.md | 13 ++ synapse/app/_base.py | 2 + synapse/config/third_party_event_rules.py | 15 -- synapse/events/third_party_rules.py | 245 +++++++++++++++++++++++----- synapse/handlers/federation.py | 4 +- synapse/handlers/message.py | 8 +- synapse/handlers/room.py | 10 +- synapse/module_api/__init__.py | 6 + tests/rest/client/test_third_party_rules.py | 132 ++++++++++++--- 12 files changed, 403 insertions(+), 108 deletions(-) create mode 100644 changelog.d/10386.removal (limited to 'tests/rest') diff --git a/changelog.d/10386.removal b/changelog.d/10386.removal new file mode 100644 index 0000000000..800a6143d7 --- /dev/null +++ b/changelog.d/10386.removal @@ -0,0 +1 @@ +The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information. diff --git a/docs/modules.md b/docs/modules.md index c4cb7018f7..9a430390a4 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -186,7 +186,7 @@ The arguments passed to this callback are: ```python async def check_media_file_for_spam( file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper", - file_info: "synapse.rest.media.v1._base.FileInfo" + file_info: "synapse.rest.media.v1._base.FileInfo", ) -> bool ``` @@ -223,6 +223,66 @@ Called after successfully registering a user, in case the module needs to perfor operations to keep track of them. (e.g. add them to a database table). The user is represented by their Matrix user ID. +#### Third party rules callbacks + +Third party rules callbacks allow module developers to add extra checks to verify the +validity of incoming events. Third party event rules callbacks can be registered using +the module API's `register_third_party_rules_callbacks` method. + +The available third party rules callbacks are: + +```python +async def check_event_allowed( + event: "synapse.events.EventBase", + state_events: "synapse.types.StateMap", +) -> Tuple[bool, Optional[dict]] +``` + +** +This callback is very experimental and can and will break without notice. Module developers +are encouraged to implement `check_event_for_spam` from the spam checker category instead. +** + +Called when processing any incoming event, with the event and a `StateMap` +representing the current state of the room the event is being sent into. A `StateMap` is +a dictionary that maps tuples containing an event type and a state key to the +corresponding state event. For example retrieving the room's `m.room.create` event from +the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`. +The module must return a boolean indicating whether the event can be allowed. + +Note that this callback function processes incoming events coming via federation +traffic (on top of client traffic). This means denying an event might cause the local +copy of the room's history to diverge from that of remote servers. This may cause +federation issues in the room. It is strongly recommended to only deny events using this +callback function if the sender is a local user, or in a private federation in which all +servers are using the same module, with the same configuration. + +If the boolean returned by the module is `True`, it may also tell Synapse to replace the +event with new data by returning the new event's data as a dictionary. In order to do +that, it is recommended the module calls `event.get_dict()` to get the current event as a +dictionary, and modify the returned dictionary accordingly. + +Note that replacing the event only works for events sent by local users, not for events +received over federation. + +```python +async def on_create_room( + requester: "synapse.types.Requester", + request_content: dict, + is_requester_admin: bool, +) -> None +``` + +Called when processing a room creation request, with the `Requester` object for the user +performing the request, a dictionary representing the room creation request's JSON body +(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom) +for a list of possible parameters), and a boolean indicating whether the user performing +the request is a server admin. + +Modules can modify the `request_content` (by e.g. adding events to its `initial_state`), +or deny the room's creation by raising a `module_api.errors.SynapseError`. + + ### Porting an existing module that uses the old interface In order to port a module that uses Synapse's old module interface, its author needs to: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f4845a5841..853c2f6899 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2654,19 +2654,6 @@ stats: # action: allow -# Server admins can define a Python module that implements extra rules for -# allowing or denying incoming events. In order to work, this module needs to -# override the methods defined in synapse/events/third_party_rules.py. -# -# This feature is designed to be used in closed federations only, where each -# participating server enforces the same rules. -# -#third_party_event_rules: -# module: "my_custom_project.SuperRulesSet" -# config: -# example_option: 'things' - - ## Opentracing ## # These settings enable opentracing, which implements distributed tracing. diff --git a/docs/upgrade.md b/docs/upgrade.md index db0450f563..c8f4a2c171 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -86,6 +86,19 @@ process, for example: ``` +# Upgrading to v1.39.0 + +## Deprecation of the current third-party rules module interface + +The current third-party rules module interface is deprecated in favour of the new generic +modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer +to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface) +to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules) +to update their configuration once the modules they are using have been updated. + +We plan to remove support for the current third-party rules interface in September 2021. + + # Upgrading to v1.38.0 ## Re-indexing of `events` table on Postgres databases diff --git a/synapse/app/_base.py b/synapse/app/_base.py index b30571fe49..50a02f51f5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -368,6 +369,7 @@ async def start(hs: "HomeServer"): module(config=config, api=module_api) load_legacy_spam_checkers(hs) + load_legacy_third_party_event_rules(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index f502ff539e..a3fae02420 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config): self.third_party_event_rules = load_module( provider, ("third_party_event_rules",) ) - - def generate_config_section(self, **kwargs): - return """\ - # Server admins can define a Python module that implements extra rules for - # allowing or denying incoming events. In order to work, this module needs to - # override the methods defined in synapse/events/third_party_rules.py. - # - # This feature is designed to be used in closed federations only, where each - # participating server enforces the same rules. - # - #third_party_event_rules: - # module: "my_custom_project.SuperRulesSet" - # config: - # example_option: 'things' - """ diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index f7944fd834..7a6eb3e516 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -11,16 +11,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from typing import TYPE_CHECKING, Union - +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Requester, StateMap +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + + +CHECK_EVENT_ALLOWED_CALLBACK = Callable[ + [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] +] +ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] +CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ + [str, str, StateMap[EventBase]], Awaitable[bool] +] +CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ + [str, StateMap[EventBase], str], Awaitable[bool] +] + + +def load_legacy_third_party_event_rules(hs: "HomeServer"): + """Wrapper that loads a third party event rules module configured using the old + configuration, and registers the hooks they implement. + """ + if hs.config.third_party_event_rules is None: + return + + module, config = hs.config.third_party_event_rules + + api = hs.get_module_api() + third_party_rules = module(config=config, module_api=api) + + # The known hooks. If a module implements a method which name appears in this set, + # we'll want to register it. + third_party_event_rules_methods = { + "check_event_allowed", + "on_create_room", + "check_threepid_can_be_invited", + "check_visibility_can_be_modified", + } + + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We return a separate wrapper for these methods because, in order to wrap them + # correctly, we need to await its result. Therefore it doesn't make a lot of + # sense to make it go through the run() wrapper. + if f.__name__ == "check_event_allowed": + + # We need to wrap check_event_allowed because its old form would return either + # a boolean or a dict, but now we want to return the dict separately from the + # boolean. + async def wrap_check_event_allowed( + event: EventBase, + state_events: StateMap[EventBase], + ) -> Tuple[bool, Optional[dict]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(event, state_events) + if isinstance(res, dict): + return True, res + else: + return res, None + + return wrap_check_event_allowed + + if f.__name__ == "on_create_room": + + # We need to wrap on_create_room because its old form would return a boolean + # if the room creation is denied, but now we just want it to raise an + # exception. + async def wrap_on_create_room( + requester: Requester, config: dict, is_requester_admin: bool + ) -> None: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(requester, config, is_requester_admin) + if res is False: + raise SynapseError( + 403, + "Room creation forbidden with these parameters", + ) + + return wrap_on_create_room + + def run(*args, **kwargs): + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None + + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. + hooks = { + hook: async_wrapper(getattr(third_party_rules, hook, None)) + for hook in third_party_event_rules_methods + } + + api.register_third_party_rules_callbacks(**hooks) + class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra @@ -35,36 +143,65 @@ class ThirdPartyEventRules: self.store = hs.get_datastore() - module = None - config = None - if hs.config.third_party_event_rules: - module, config = hs.config.third_party_event_rules + self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] + self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] + self._check_threepid_can_be_invited_callbacks: List[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = [] + self._check_visibility_can_be_modified_callbacks: List[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = [] + + def register_third_party_rules_callbacks( + self, + check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, + check_threepid_can_be_invited: Optional[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = None, + check_visibility_can_be_modified: Optional[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = None, + ): + """Register callbacks from modules for each hook.""" + if check_event_allowed is not None: + self._check_event_allowed_callbacks.append(check_event_allowed) + + if on_create_room is not None: + self._on_create_room_callbacks.append(on_create_room) + + if check_threepid_can_be_invited is not None: + self._check_threepid_can_be_invited_callbacks.append( + check_threepid_can_be_invited, + ) - if module is not None: - self.third_party_rules = module( - config=config, - module_api=hs.get_module_api(), + if check_visibility_can_be_modified is not None: + self._check_visibility_can_be_modified_callbacks.append( + check_visibility_can_be_modified, ) async def check_event_allowed( self, event: EventBase, context: EventContext - ) -> Union[bool, dict]: + ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. The module can return: * True: the event is allowed. * False: the event is not allowed, and should be rejected with M_FORBIDDEN. - * a dict: replacement event data. + + If the event is allowed, the module can also return a dictionary to use as a + replacement for the event. Args: event: The event to be checked. context: The context of the event. Returns: - The result from the ThirdPartyRules module, as above + The result from the ThirdPartyRules module, as above. """ - if self.third_party_rules is None: - return True + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_event_allowed_callbacks) == 0: + return True, None prev_state_ids = await context.get_prev_state_ids() @@ -77,29 +214,46 @@ class ThirdPartyEventRules: # the hashes and signatures. event.freeze() - return await self.third_party_rules.check_event_allowed(event, state_events) + for callback in self._check_event_allowed_callbacks: + try: + res, replacement_data = await callback(event, state_events) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue + + # Return if the event shouldn't be allowed or if the module came up with a + # replacement dict for the event. + if res is False: + return res, None + elif isinstance(replacement_data, dict): + return True, replacement_data + + return True, None async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ) -> bool: - """Intercept requests to create room to allow, deny or update the - request config. + ) -> None: + """Intercept requests to create room to maybe deny it (via an exception) or + update the request config. Args: requester config: The creation config from the client. is_requester_admin: If the requester is an admin - - Returns: - Whether room creation is allowed or denied. """ - - if self.third_party_rules is None: - return True - - return await self.third_party_rules.on_create_room( - requester, config, is_requester_admin - ) + for callback in self._on_create_room_callbacks: + try: + await callback(requester, config, is_requester_admin) + except Exception as e: + # Don't silence the errors raised by this callback since we expect it to + # raise an exception to deny the creation of the room; instead make sure + # it's a SynapseError we can send to clients. + if not isinstance(e, SynapseError): + e = SynapseError( + 403, "Room creation forbidden with these parameters" + ) + + raise e async def check_threepid_can_be_invited( self, medium: str, address: str, room_id: str @@ -114,15 +268,20 @@ class ThirdPartyEventRules: Returns: True if the 3PID can be invited, False if not. """ - - if self.third_party_rules is None: + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_threepid_can_be_invited_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) + for callback in self._check_threepid_can_be_invited_callbacks: + try: + if await callback(medium, address, state_events) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def check_visibility_can_be_modified( self, room_id: str, new_visibility: str @@ -137,18 +296,20 @@ class ThirdPartyEventRules: Returns: True if the room's visibility can be modified, False if not. """ - if self.third_party_rules is None: - return True - - check_func = getattr( - self.third_party_rules, "check_visibility_can_be_modified", None - ) - if not check_func or not callable(check_func): + # Bail out early without hitting the store if we don't have any callback + if len(self._check_visibility_can_be_modified_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await check_func(room_id, state_events, new_visibility) + for callback in self._check_visibility_can_be_modified_callbacks: + try: + if await callback(room_id, state_events, new_visibility) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index cf389be3e4..5728719909 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler): # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c7fe4ff89e..8a0024ce84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -949,10 +949,10 @@ class EventCreationHandler: if requester: context.app_service = requester.app_service - third_party_result = await self.third_party_event_rules.check_event_allowed( + res, new_content = await self.third_party_event_rules.check_event_allowed( event, context ) - if not third_party_result: + if res is False: logger.info( "Event %s forbidden by third-party rules", event, @@ -960,11 +960,11 @@ class EventCreationHandler: raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - elif isinstance(third_party_result, dict): + elif new_content is not None: # the third-party rules want to replace the event. We'll need to build a new # event. event, context = await self._rebuild_event_after_third_party_rules( - third_party_result, event + new_content, event ) self.validator.validate_new(event, self.config) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 64656fda22..370561e549 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler): else: is_requester_admin = await self.auth.is_server_admin(requester.user) - # Check whether the third party rules allows/changes the room create - # request. - event_allowed = await self.third_party_event_rules.on_create_room( + # Let the third party rules modify the room creation config if needed, or abort + # the room creation entirely with an exception. + await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) - if not event_allowed: - raise SynapseError( - 403, "You are not permitted to create rooms", Codes.FORBIDDEN - ) if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 5df9349134..1259fc2d90 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -110,6 +110,7 @@ class ModuleApi: self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() + self._third_party_event_rules = hs.get_third_party_event_rules() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -124,6 +125,11 @@ class ModuleApi: """Registers callbacks for account validity capabilities.""" return self._account_validity_handler.register_account_validity_callbacks + @property + def register_third_party_rules_callbacks(self): + """Registers callbacks for third party event rules capabilities.""" + return self._third_party_event_rules.register_third_party_rules_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index c5e1c5458b..28dd47a28b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -16,17 +16,19 @@ from typing import Dict from unittest.mock import Mock from synapse.events import EventBase +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import Requester, StateMap +from synapse.util.frozenutils import unfreeze from tests import unittest thread_local = threading.local() -class ThirdPartyRulesTestModule: +class LegacyThirdPartyRulesTestModule: def __init__(self, config: Dict, module_api: ModuleApi): # keep a record of the "current" rules module, so that the test can patch # it if desired. @@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule: return config -def current_rules_module() -> ThirdPartyRulesTestModule: - return thread_local.rules_module +class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return False + + +class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + d = event.get_dict() + content = unfreeze(event.content) + content["foo"] = "bar" + d["content"] = content + return d class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): @@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def default_config(self): - config = super().default_config() - config["third_party_event_rules"] = { - "module": __name__ + ".ThirdPartyRulesTestModule", - "config": {}, - } - return config + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + load_legacy_third_party_event_rules(hs) + + return hs def prepare(self, reactor, clock, homeserver): # Create a user and room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.tok = self.login("kermit", "monkey") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + # Some tests might prevent room creation on purpose. + try: + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + except Exception: + pass def test_third_party_rules(self): """Tests that a forbidden event is forbidden from being sent, but an allowed one @@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # patch the rules module with a Mock which will return False for some event # types async def check(ev, state): - return ev.type != "foo.bar.forbidden" + return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) - current_rules_module().check_event_allowed = callback + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [ + callback + ] channel = self.make_request( "PUT", @@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # first patch the event checker so that it will try to modify the event async def check(ev: EventBase, state): ev.content = {"x": "y"} - return True + return True, None - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"500", channel.result) + # check_event_allowed has some error handling, so it shouldn't 500 just because a + # module did something bad. + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "x") def test_modify_event(self): """The module can return a modified version of the event""" @@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): async def check(ev: EventBase, state): d = ev.get_dict() d["content"] = {"x": "y"} - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "msgtype": "m.text", "body": d["content"]["body"].upper(), } - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # Send an event, then edit it. channel = self.make_request( @@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.assertEqual(ev["content"]["body"], "EDITED BODY") def test_send_event(self): - """Tests that the module can send an event into a room via the module api""" + """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", "body": "Hello!", @@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "sender": self.user_id, } event: EventBase = self.get_success( - current_rules_module().module_api.create_and_send_event_into_room( - event_dict - ) + self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) self.assertEquals(event.type, "m.room.message") self.assertEquals(event.content, content) + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyChangeEvents", + "config": {}, + } + } + ) + def test_legacy_check_event_allowed(self): + """Tests that the wrapper for legacy check_event_allowed callbacks works + correctly. + """ + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + self.assertIn("foo", channel.json_body["content"].keys()) + self.assertEqual(channel.json_body["content"]["foo"], "bar") + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyDenyNewRooms", + "config": {}, + } + } + ) + def test_legacy_on_create_room(self): + """Tests that the wrapper for legacy on_create_room callbacks works + correctly. + """ + self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) -- cgit 1.5.1