From 6c02cca95f8136010062b6af0fa36a2906a96a6b Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 1 Jul 2021 11:26:24 +0200 Subject: Add SSO `external_ids` to Query User Account admin API (#10261) Related to #10251 --- tests/rest/admin/test_user.py | 224 ++++++++++++++++++++++++++---------------- 1 file changed, 140 insertions(+), 84 deletions(-) (limited to 'tests') diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index d599a4c984..a34d051734 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1150,7 +1150,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1160,7 +1160,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): @@ -1177,6 +1177,58 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) + def test_get_user(self): + """ + Test a simple get of a user. + """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) + self._check_fields(channel.json_body) + + def test_get_user_with_sso(self): + """ + Test get a user with SSO details. + """ + self.get_success( + self.store.record_user_external_id( + "auth_provider1", "external_id1", self.other_user + ) + ) + self.get_success( + self.store.record_user_external_id( + "auth_provider2", "external_id2", self.other_user + ) + ) + + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][1]["external_id"] + ) + self.assertEqual( + "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + ) + self._check_fields(channel.json_body) + def test_create_server_admin(self): """ Check that a new admin user is created successfully. @@ -1184,30 +1236,29 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user (server admin) - body = json.dumps( - { - "password": "abc123", - "admin": True, - "displayname": "Bob's name", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": "mxc://fibble/wibble", - } - ) + body = { + "password": "abc123", + "admin": True, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1216,7 +1267,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1225,6 +1276,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["is_guest"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) def test_create_user(self): """ @@ -1233,30 +1285,29 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "admin": False, - "displayname": "Bob's name", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": "mxc://fibble/wibble", - } - ) + body = { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1265,7 +1316,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1275,6 +1326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1311,16 +1363,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123", "admin": False}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "admin": False}, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1350,17 +1400,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123", "admin": False}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "admin": False}, ) # Admin user is not blocked by mau anymore - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1382,21 +1430,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - ) + body = { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1426,21 +1472,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - ) + body = { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1457,16 +1501,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Change password - body = json.dumps({"password": "hahaha"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "hahaha"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self._check_fields(channel.json_body) def test_set_displayname(self): """ @@ -1474,16 +1517,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Modify user - body = json.dumps({"displayname": "foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "foobar"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1494,7 +1535,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1504,18 +1545,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Delete old and add new threepid to user - body = json.dumps( - {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} - ) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1527,7 +1564,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1552,7 +1589,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1567,7 +1604,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1583,7 +1620,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1610,7 +1647,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -1626,7 +1663,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -1650,7 +1687,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -1659,7 +1696,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -1681,7 +1718,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -1691,7 +1728,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1713,7 +1750,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -1723,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1742,7 +1779,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -1753,7 +1790,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -1772,7 +1809,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -1783,7 +1820,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -1796,7 +1833,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -1805,7 +1842,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -1830,7 +1867,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -1838,6 +1875,25 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIsNone(self.get_success(d)) self._is_erased(user_id, True) + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + + Args: + content: Content dictionary to check + """ + self.assertIn("displayname", content) + self.assertIn("threepids", content) + self.assertIn("avatar_url", content) + self.assertIn("admin", content) + self.assertIn("deactivated", content) + self.assertIn("shadow_banned", content) + self.assertIn("password_hash", content) + self.assertIn("creation_ts", content) + self.assertIn("appservice_id", content) + self.assertIn("consent_server_notice_sent", content) + self.assertIn("consent_version", content) + self.assertIn("external_ids", content) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 8d609435c0053fc4decbc3f9c3603e728912749c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 1 Jul 2021 14:25:37 -0400 Subject: Move methods involving event authentication to EventAuthHandler. (#10268) Instead of mixing them with user authentication methods. --- changelog.d/10268.misc | 1 + synapse/api/auth.py | 75 +------------------------------- synapse/events/builder.py | 12 ++--- synapse/federation/federation_server.py | 6 +-- synapse/handlers/event_auth.py | 62 +++++++++++++++++++++++++- synapse/handlers/federation.py | 36 +++++++++------ synapse/handlers/message.py | 9 ++-- synapse/handlers/room.py | 3 +- synapse/handlers/space_summary.py | 6 ++- synapse/push/bulk_push_rule_evaluator.py | 4 +- tests/handlers/test_presence.py | 4 +- 11 files changed, 112 insertions(+), 106 deletions(-) create mode 100644 changelog.d/10268.misc (limited to 'tests') diff --git a/changelog.d/10268.misc b/changelog.d/10268.misc new file mode 100644 index 0000000000..9e3f60c72f --- /dev/null +++ b/changelog.d/10268.misc @@ -0,0 +1 @@ +Move event authentication methods from `Auth` to `EventAuthHandler`. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f8b068e563..307f5f9a94 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple import pymacaroons from netaddr import IPAddress @@ -28,10 +28,8 @@ from synapse.api.errors import ( InvalidClientTokenError, MissingClientTokenError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.appservice import ApplicationService from synapse.events import EventBase -from synapse.events.builder import EventBuilder from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing @@ -39,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester, StateMap, UserID, create_requester from synapse.util.caches.lrucache import LruCache from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry -from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,15 +44,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -AuthEventTypes = ( - EventTypes.Create, - EventTypes.Member, - EventTypes.PowerLevels, - EventTypes.JoinRules, - EventTypes.RoomHistoryVisibility, - EventTypes.ThirdPartyInvite, -) - # guests always get this device id. GUEST_DEVICE_ID = "guest_device" @@ -66,9 +54,7 @@ class _InvalidMacaroonException(Exception): class Auth: """ - FIXME: This class contains a mix of functions for authenticating users - of our client-server API and authenticating events added to room graphs. - The latter should be moved to synapse.handlers.event_auth.EventAuthHandler. + This class contains functions for authenticating users of our client-server API. """ def __init__(self, hs: "HomeServer"): @@ -90,18 +76,6 @@ class Auth: self._macaroon_secret_key = hs.config.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users - async def check_from_context( - self, room_version: str, event, context, do_sig_check=True - ) -> None: - auth_event_ids = event.auth_event_ids() - auth_events_by_id = await self.store.get_events(auth_event_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} - - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - event_auth.check( - room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check - ) - async def check_user_in_room( self, room_id: str, @@ -152,13 +126,6 @@ class Auth: raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - async def check_host_in_room(self, room_id: str, host: str) -> bool: - with Measure(self.clock, "check_host_in_room"): - return await self.store.is_host_joined(room_id, host) - - def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]: - return event_auth.get_public_keys(invite_event) - async def get_user_by_req( self, request: SynapseRequest, @@ -489,44 +456,6 @@ class Auth: """ return await self.store.is_server_admin(user) - def compute_auth_events( - self, - event: Union[EventBase, EventBuilder], - current_state_ids: StateMap[str], - for_verification: bool = False, - ) -> List[str]: - """Given an event and current state return the list of event IDs used - to auth an event. - - If `for_verification` is False then only return auth events that - should be added to the event's `auth_events`. - - Returns: - List of event IDs. - """ - - if event.type == EventTypes.Create: - return [] - - # Currently we ignore the `for_verification` flag even though there are - # some situations where we can drop particular auth events when adding - # to the event's `auth_events` (e.g. joins pointing to previous joins - # when room is publicly joinable). Dropping event IDs has the - # advantage that the auth chain for the room grows slower, but we use - # the auth chain in state resolution v2 to order events, which means - # care must be taken if dropping events to ensure that it doesn't - # introduce undesirable "state reset" behaviour. - # - # All of which sounds a bit tricky so we don't bother for now. - - auth_ids = [] - for etype, state_key in event_auth.auth_types_for_event(event): - auth_ev_id = current_state_ids.get((etype, state_key)) - if auth_ev_id: - auth_ids.append(auth_ev_id) - - return auth_ids - async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool: """Determine whether the user is allowed to edit the room's entry in the published room list. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index fb48ec8541..26e3950859 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -34,7 +34,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string if TYPE_CHECKING: - from synapse.api.auth import Auth + from synapse.handlers.event_auth import EventAuthHandler from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ class EventBuilder: """ _state: StateHandler - _auth: "Auth" + _event_auth_handler: "EventAuthHandler" _store: DataStore _clock: Clock _hostname: str @@ -125,7 +125,9 @@ class EventBuilder: state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_event_ids = self._auth.compute_auth_events(self, state_ids) + auth_event_ids = self._event_auth_handler.compute_auth_events( + self, state_ids + ) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: @@ -193,7 +195,7 @@ class EventBuilderFactory: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() def new(self, room_version: str, key_values: dict) -> EventBuilder: """Generate an event builder appropriate for the given room version @@ -229,7 +231,7 @@ class EventBuilderFactory: return EventBuilder( store=self.store, state=self.state, - auth=self.auth, + event_auth_handler=self._event_auth_handler, clock=self.clock, hostname=self.hostname, signing_key=self.signing_key, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e93b7577fe..b312d0b809 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -108,9 +108,9 @@ class FederationServer(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.auth = hs.get_auth() self.handler = hs.get_federation_handler() self.state = hs.get_state_handler() + self._event_auth_handler = hs.get_event_auth_handler() self.device_handler = hs.get_device_handler() @@ -420,7 +420,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -453,7 +453,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 989996b628..41dbdfd0a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Collection, Optional +from typing import TYPE_CHECKING, Collection, List, Optional, Union +from synapse import event_auth from synapse.api.constants import ( EventTypes, JoinRules, @@ -20,9 +21,11 @@ from synapse.api.constants import ( RestrictedJoinRuleTypes, ) from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersion +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase +from synapse.events.builder import EventBuilder from synapse.types import StateMap +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,8 +37,63 @@ class EventAuthHandler: """ def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() self._store = hs.get_datastore() + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ) -> None: + auth_event_ids = event.auth_event_ids() + auth_events_by_id = await self._store.get_events(auth_event_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} + + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + event_auth.check( + room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check + ) + + def compute_auth_events( + self, + event: Union[EventBase, EventBuilder], + current_state_ids: StateMap[str], + for_verification: bool = False, + ) -> List[str]: + """Given an event and current state return the list of event IDs used + to auth an event. + + If `for_verification` is False then only return auth events that + should be added to the event's `auth_events`. + + Returns: + List of event IDs. + """ + + if event.type == EventTypes.Create: + return [] + + # Currently we ignore the `for_verification` flag even though there are + # some situations where we can drop particular auth events when adding + # to the event's `auth_events` (e.g. joins pointing to previous joins + # when room is publicly joinable). Dropping event IDs has the + # advantage that the auth chain for the room grows slower, but we use + # the auth chain in state resolution v2 to order events, which means + # care must be taken if dropping events to ensure that it doesn't + # introduce undesirable "state reset" behaviour. + # + # All of which sounds a bit tricky so we don't bother for now. + + auth_ids = [] + for etype, state_key in event_auth.auth_types_for_event(event): + auth_ev_id = current_state_ids.get((etype, state_key)) + if auth_ev_id: + auth_ids.append(auth_ev_id) + + return auth_ids + + async def check_host_in_room(self, room_id: str, host: str) -> bool: + with Measure(self._clock, "check_host_in_room"): + return await self._store.is_host_joined(room_id, host) + async def check_restricted_join_rules( self, state_ids: StateMap[str], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d929c65131..991ec9919a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -250,7 +250,9 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) if not is_in_room: logger.info( "Ignoring PDU from %s as we're not in the room", @@ -1674,7 +1676,9 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) # now check that we are *still* in the room - is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) if not is_in_room: logger.info( "Got /make_join request for room %s we are no longer in", @@ -1705,7 +1709,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) @@ -1877,7 +1881,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: @@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: @@ -2111,7 +2115,7 @@ class FederationHandler(BaseHandler): async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2146,7 +2150,9 @@ class FederationHandler(BaseHandler): ) if event: - in_room = await self.auth.check_host_in_room(event.room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room( + event.room_id, origin + ) if not in_room: raise AuthError(403, "Host not in room.") @@ -2499,7 +2505,7 @@ class FederationHandler(BaseHandler): latest_events: List[str], limit: int, ) -> List[EventBase]: - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2562,7 +2568,7 @@ class FederationHandler(BaseHandler): if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) @@ -2991,7 +2997,7 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if await self.auth.check_host_in_room(room_id, self.hs.hostname): + if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname): room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) @@ -3011,7 +3017,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -3054,7 +3062,9 @@ class FederationHandler(BaseHandler): ) try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e @@ -3142,7 +3152,7 @@ class FederationHandler(BaseHandler): last_exception = None # type: Optional[Exception] # for each public key in the 3pid invite event - for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + for public_key_object in event_auth.get_public_keys(invite_event): try: # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 364c5cd2d3..66e40a915d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -385,6 +385,7 @@ class EventCreationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state = hs.get_state_handler() @@ -597,7 +598,7 @@ class EventCreationHandler: (e.type, e.state_key): e.event_id for e in auth_events } # Actually strip down and use the necessary auth events - auth_event_ids = self.auth.compute_auth_events( + auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, current_state_ids=auth_event_state_map, for_verification=False, @@ -1056,7 +1057,9 @@ class EventCreationHandler: assert event.content["membership"] == Membership.LEAVE else: try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -1381,7 +1384,7 @@ class EventCreationHandler: raise AuthError(403, "Redacting server ACL events is not permitted") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_map = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 835d874cee..579b1b93c5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() + self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config # Room state based off defined presets @@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler): }, ) old_room_version = await self.store.get_room_version_id(old_room_id) - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( old_room_version, tombstone_event, tombstone_context ) diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 266f369883..b585057ec3 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -472,7 +472,7 @@ class SpaceSummaryHandler: # If this is a request over federation, check if the host is in the room or # is in one of the spaces specified via the join rules. elif origin: - if await self._auth.check_host_in_room(room_id, origin): + if await self._event_auth_handler.check_host_in_room(room_id, origin): return True # Alternately, if the host has a user in any of the spaces specified @@ -485,7 +485,9 @@ class SpaceSummaryHandler: await self._event_auth_handler.get_rooms_that_allow_join(state_ids) ) for space_id in allowed_rooms: - if await self._auth.check_host_in_room(space_id, origin): + if await self._event_auth_handler.check_host_in_room( + space_id, origin + ): return True # otherwise, check if the room is peekable diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 350646f458..669ea462e2 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -104,7 +104,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a # time. Keyed off room_id. @@ -172,7 +172,7 @@ class BulkPushRuleEvaluator: # not having a power level event is an extreme edge case auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index dfb9b3a0fa..18e92e90d7 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -734,7 +734,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() # We don't actually check signatures in tests, so lets just create a # random key to use. @@ -846,7 +846,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): builder = EventBuilder( state=self.state, - auth=self.auth, + event_auth_handler=self._event_auth_handler, store=self.store, clock=self.clock, hostname=hostname, -- cgit 1.5.1 From 7a5873277ef456e8446a05468ccae2d81e363977 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Jul 2021 16:32:12 +0100 Subject: Add support for evicting cache entries based on last access time. (#10205) --- changelog.d/10205.feature | 1 + docs/sample_config.yaml | 62 ++++++----- mypy.ini | 1 + synapse/app/_base.py | 11 +- synapse/config/_base.pyi | 2 + synapse/config/cache.py | 70 +++++++----- synapse/util/caches/lrucache.py | 237 ++++++++++++++++++++++++++++++++++------ synapse/util/linked_list.py | 150 +++++++++++++++++++++++++ tests/util/test_lrucache.py | 46 +++++++- 9 files changed, 485 insertions(+), 95 deletions(-) create mode 100644 changelog.d/10205.feature create mode 100644 synapse/util/linked_list.py (limited to 'tests') diff --git a/changelog.d/10205.feature b/changelog.d/10205.feature new file mode 100644 index 0000000000..db3fd22587 --- /dev/null +++ b/changelog.d/10205.feature @@ -0,0 +1 @@ +Add support for evicting cache entries based on last access time. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6fcc022b47..c04aca1f42 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -673,35 +673,41 @@ retention: #event_cache_size: 10K caches: - # Controls the global cache factor, which is the default cache factor - # for all caches if a specific factor for that cache is not otherwise - # set. - # - # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment - # variable. Setting by environment variable takes priority over - # setting through the config file. - # - # Defaults to 0.5, which will half the size of all caches. - # - #global_factor: 1.0 + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 - # A dictionary of cache name to cache factor for that individual - # cache. Overrides the global cache factor for a given cache. - # - # These can also be set through environment variables comprised - # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital - # letters and underscores. Setting by environment variable - # takes priority over setting through the config file. - # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 - # - # Some caches have '*' and other characters that are not - # alphanumeric or underscores. These caches can be named with or - # without the special characters stripped. For example, to specify - # the cache factor for `*stateGroupCache*` via an environment - # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. - # - per_cache_factors: - #get_users_who_share_room_with_user: 2.0 + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + # Some caches have '*' and other characters that are not + # alphanumeric or underscores. These caches can be named with or + # without the special characters stripped. For example, to specify + # the cache factor for `*stateGroupCache*` via an environment + # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + # Controls how long an entry can be in a cache without having been + # accessed before being evicted. Defaults to None, which means + # entries are never evicted based on time. + # + #expiry_time: 30m ## Database ## diff --git a/mypy.ini b/mypy.ini index c4ff0e6618..72ce932d73 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,6 +75,7 @@ files = synapse/util/daemonize.py, synapse/util/hash.py, synapse/util/iterutils.py, + synapse/util/linked_list.py, synapse/util/metrics.py, synapse/util/macaroons.py, synapse/util/module_loader.py, diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8879136881..b30571fe49 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -21,7 +21,7 @@ import socket import sys import traceback import warnings -from typing import Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Awaitable, Callable, Iterable from cryptography.utils import CryptographyDeprecationWarning from typing_extensions import NoReturn @@ -41,10 +41,14 @@ from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats +from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # list of tuples of function, args list, kwargs dict @@ -312,7 +316,7 @@ def refresh_certificate(hs): logger.info("Context factories updated.") -async def start(hs: "synapse.server.HomeServer"): +async def start(hs: "HomeServer"): """ Start a Synapse server or worker. @@ -365,6 +369,9 @@ async def start(hs: "synapse.server.HomeServer"): load_legacy_spam_checkers(hs) + # If we've configured an expiry time for caches, start the background job now. + setup_expire_lru_cache_entries(hs) + # It is now safe to start your Synapse. hs.start_listening() hs.get_datastore().db_pool.start_profiling() diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 23ca0c83c1..06fbd1166b 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -5,6 +5,7 @@ from synapse.config import ( api, appservice, auth, + cache, captcha, cas, consent, @@ -88,6 +89,7 @@ class RootConfig: tracer: tracer.TracerConfig redis: redis.RedisConfig modules: modules.ModulesConfig + caches: cache.CacheConfig federation: federation.FederationConfig config_classes: List = ... diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 91165ee1ce..7789b40323 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -116,35 +116,41 @@ class CacheConfig(Config): #event_cache_size: 10K caches: - # Controls the global cache factor, which is the default cache factor - # for all caches if a specific factor for that cache is not otherwise - # set. - # - # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment - # variable. Setting by environment variable takes priority over - # setting through the config file. - # - # Defaults to 0.5, which will half the size of all caches. - # - #global_factor: 1.0 - - # A dictionary of cache name to cache factor for that individual - # cache. Overrides the global cache factor for a given cache. - # - # These can also be set through environment variables comprised - # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital - # letters and underscores. Setting by environment variable - # takes priority over setting through the config file. - # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 - # - # Some caches have '*' and other characters that are not - # alphanumeric or underscores. These caches can be named with or - # without the special characters stripped. For example, to specify - # the cache factor for `*stateGroupCache*` via an environment - # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. - # - per_cache_factors: - #get_users_who_share_room_with_user: 2.0 + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + # Some caches have '*' and other characters that are not + # alphanumeric or underscores. These caches can be named with or + # without the special characters stripped. For example, to specify + # the cache factor for `*stateGroupCache*` via an environment + # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + # Controls how long an entry can be in a cache without having been + # accessed before being evicted. Defaults to None, which means + # entries are never evicted based on time. + # + #expiry_time: 30m """ def read_config(self, config, **kwargs): @@ -200,6 +206,12 @@ class CacheConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) + expiry_time = cache_config.get("expiry_time") + if expiry_time: + self.expiry_time_msec = self.parse_duration(expiry_time) + else: + self.expiry_time_msec = None + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index d89e9d9b1d..4b9d0433ff 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import threading +import weakref from functools import wraps from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -31,10 +34,19 @@ from typing import ( from typing_extensions import Literal +from twisted.internet import reactor + from synapse.config import cache as cache_config -from synapse.util import caches +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.util import Clock, caches from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.linked_list import ListNode + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) try: from pympler.asizeof import Asizer @@ -82,19 +94,126 @@ def enumerate_leaves(node, depth): yield m +P = TypeVar("P") + + +class _TimedListNode(ListNode[P]): + """A `ListNode` that tracks last access time.""" + + __slots__ = ["last_access_ts_secs"] + + def update_last_access(self, clock: Clock): + self.last_access_ts_secs = int(clock.time()) + + +# Whether to insert new cache entries to the global list. We only add to it if +# time based eviction is enabled. +USE_GLOBAL_LIST = False + +# A linked list of all cache entries, allowing efficient time based eviction. +GLOBAL_ROOT = ListNode["_Node"].create_root_node() + + +@wrap_as_background_process("LruCache._expire_old_entries") +async def _expire_old_entries(clock: Clock, expiry_seconds: int): + """Walks the global cache list to find cache entries that haven't been + accessed in the given number of seconds. + """ + + now = int(clock.time()) + node = GLOBAL_ROOT.prev_node + assert node is not None + + i = 0 + + logger.debug("Searching for stale caches") + + while node is not GLOBAL_ROOT: + # Only the root node isn't a `_TimedListNode`. + assert isinstance(node, _TimedListNode) + + if node.last_access_ts_secs > now - expiry_seconds: + break + + cache_entry = node.get_cache_entry() + next_node = node.prev_node + + # The node should always have a reference to a cache entry and a valid + # `prev_node`, as we only drop them when we remove the node from the + # list. + assert next_node is not None + assert cache_entry is not None + cache_entry.drop_from_cache() + + # If we do lots of work at once we yield to allow other stuff to happen. + if (i + 1) % 10000 == 0: + logger.debug("Waiting during drop") + await clock.sleep(0) + logger.debug("Waking during drop") + + node = next_node + + # If we've yielded then our current node may have been evicted, so we + # need to check that its still valid. + if node.prev_node is None: + break + + i += 1 + + logger.info("Dropped %d items from caches", i) + + +def setup_expire_lru_cache_entries(hs: "HomeServer"): + """Start a background job that expires all cache entries if they have not + been accessed for the given number of seconds. + """ + if not hs.config.caches.expiry_time_msec: + return + + logger.info( + "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000 + ) + + global USE_GLOBAL_LIST + USE_GLOBAL_LIST = True + + clock = hs.get_clock() + clock.looping_call( + _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000 + ) + + class _Node: - __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"] + __slots__ = [ + "_list_node", + "_global_list_node", + "_cache", + "key", + "value", + "callbacks", + "memory", + ] def __init__( self, - prev_node, - next_node, + root: "ListNode[_Node]", key, value, + cache: "weakref.ReferenceType[LruCache]", + clock: Clock, callbacks: Collection[Callable[[], None]] = (), ): - self.prev_node = prev_node - self.next_node = next_node + self._list_node = ListNode.insert_after(self, root) + self._global_list_node = None + if USE_GLOBAL_LIST: + self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT) + self._global_list_node.update_last_access(clock) + + # We store a weak reference to the cache object so that this _Node can + # remove itself from the cache. If the cache is dropped we ensure we + # remove our entries in the lists. + self._cache = cache + self.key = key self.value = value @@ -116,11 +235,16 @@ class _Node: self.memory = ( _get_size_of(key) + _get_size_of(value) + + _get_size_of(self._list_node, recurse=False) + _get_size_of(self.callbacks, recurse=False) + _get_size_of(self, recurse=False) ) self.memory += _get_size_of(self.memory, recurse=False) + if self._global_list_node: + self.memory += _get_size_of(self._global_list_node, recurse=False) + self.memory += _get_size_of(self._global_list_node.last_access_ts_secs) + def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None: """Add to stored list of callbacks, removing duplicates.""" @@ -147,6 +271,32 @@ class _Node: self.callbacks = None + def drop_from_cache(self) -> None: + """Drop this node from the cache. + + Ensures that the entry gets removed from the cache and that we get + removed from all lists. + """ + cache = self._cache() + if not cache or not cache.pop(self.key, None): + # `cache.pop` should call `drop_from_lists()`, unless this Node had + # already been removed from the cache. + self.drop_from_lists() + + def drop_from_lists(self) -> None: + """Remove this node from the cache lists.""" + self._list_node.remove_from_list() + + if self._global_list_node: + self._global_list_node.remove_from_list() + + def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None: + """Moves this node to the front of all the lists its in.""" + self._list_node.move_after(cache_list_root) + if self._global_list_node: + self._global_list_node.move_after(GLOBAL_ROOT) + self._global_list_node.update_last_access(clock) + class LruCache(Generic[KT, VT]): """ @@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]): size_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, + clock: Optional[Clock] = None, ): """ Args: @@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]): apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config """ + # Default `clock` to something sensible. Note that we rename it to + # `real_clock` so that mypy doesn't think its still `Optional`. + if clock is None: + real_clock = Clock(reactor) + else: + real_clock = clock + cache = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]): # this is exposed for access from outside this class self.metrics = metrics - list_root = _Node(None, None, None, None) - list_root.next_node = list_root - list_root.prev_node = list_root + # We create a single weakref to self here so that we don't need to keep + # creating more each time we create a `_Node`. + weak_ref_to_self = weakref.ref(self) + + list_root = ListNode[_Node].create_root_node() lock = threading.Lock() def evict(): while cache_len() > self.max_size: + # Get the last node in the list (i.e. the oldest node). todelete = list_root.prev_node - evicted_len = delete_node(todelete) - cache.pop(todelete.key, None) + + # The list root should always have a valid `prev_node` if the + # cache is not empty. + assert todelete is not None + + # The node should always have a reference to a cache entry, as + # we only drop the cache entry when we remove the node from the + # list. + node = todelete.get_cache_entry() + assert node is not None + + evicted_len = delete_node(node) + cache.pop(node.key, None) if metrics: metrics.inc_evictions(evicted_len) @@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]): self.len = synchronized(cache_len) def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): - prev_node = list_root - next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value, callbacks) - prev_node.next_node = node - next_node.prev_node = node + node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks) cache[key] = node if size_callback: @@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]): if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node): - prev_node = node.prev_node - next_node = node.next_node - prev_node.next_node = next_node - next_node.prev_node = prev_node - prev_node = list_root - next_node = prev_node.next_node - node.prev_node = prev_node - node.next_node = next_node - prev_node.next_node = node - next_node.prev_node = node - - def delete_node(node): - prev_node = node.prev_node - next_node = node.next_node - prev_node.next_node = next_node - next_node.prev_node = prev_node + def move_node_to_front(node: _Node): + node.move_to_front(real_clock, list_root) + + def delete_node(node: _Node) -> int: + node.drop_from_lists() deleted_len = 1 if size_callback: @@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_clear() -> None: - list_root.next_node = list_root - list_root.prev_node = list_root for node in cache.values(): node.run_and_clear_callbacks() + node.drop_from_lists() + + assert list_root.next_node == list_root + assert list_root.prev_node == list_root + cache.clear() if size_callback: cached_cache_len[0] = 0 @@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]): self._on_resize() return True return False + + def __del__(self) -> None: + # We're about to be deleted, so we make sure to clear up all the nodes + # and run callbacks, etc. + # + # This happens e.g. in the sync code where we have an expiring cache of + # lru caches. + self.clear() diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py new file mode 100644 index 0000000000..a456b136f0 --- /dev/null +++ b/synapse/util/linked_list.py @@ -0,0 +1,150 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A circular doubly linked list implementation. +""" + +import threading +from typing import Generic, Optional, Type, TypeVar + +P = TypeVar("P") +LN = TypeVar("LN", bound="ListNode") + + +class ListNode(Generic[P]): + """A node in a circular doubly linked list, with an (optional) reference to + a cache entry. + + The reference should only be `None` for the root node or if the node has + been removed from the list. + """ + + # A lock to protect mutating the list prev/next pointers. + _LOCK = threading.Lock() + + # We don't use attrs here as in py3.6 you can't have `attr.s(slots=True)` + # and inherit from `Generic` for some reason + __slots__ = [ + "cache_entry", + "prev_node", + "next_node", + ] + + def __init__(self, cache_entry: Optional[P] = None) -> None: + self.cache_entry = cache_entry + self.prev_node: Optional[ListNode[P]] = None + self.next_node: Optional[ListNode[P]] = None + + @classmethod + def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]": + """Create a new linked list by creating a "root" node, which is a node + that has prev_node/next_node pointing to itself and no associated cache + entry. + """ + root = cls() + root.prev_node = root + root.next_node = root + return root + + @classmethod + def insert_after( + cls: Type[LN], + cache_entry: P, + node: "ListNode[P]", + ) -> LN: + """Create a new list node that is placed after the given node. + + Args: + cache_entry: The associated cache entry. + node: The existing node in the list to insert the new entry after. + """ + new_node = cls(cache_entry) + with cls._LOCK: + new_node._refs_insert_after(node) + return new_node + + def remove_from_list(self): + """Remove this node from the list.""" + with self._LOCK: + self._refs_remove_node_from_list() + + # We drop the reference to the cache entry to break the reference cycle + # between the list node and cache entry, allowing the two to be dropped + # immediately rather than at the next GC. + self.cache_entry = None + + def move_after(self, node: "ListNode"): + """Move this node from its current location in the list to after the + given node. + """ + with self._LOCK: + # We assert that both this node and the target node is still "alive". + assert self.prev_node + assert self.next_node + assert node.prev_node + assert node.next_node + + assert self is not node + + # Remove self from the list + self._refs_remove_node_from_list() + + # Insert self back into the list, after target node + self._refs_insert_after(node) + + def _refs_remove_node_from_list(self): + """Internal method to *just* remove the node from the list, without + e.g. clearing out the cache entry. + """ + if self.prev_node is None or self.next_node is None: + # We've already been removed from the list. + return + + prev_node = self.prev_node + next_node = self.next_node + + prev_node.next_node = next_node + next_node.prev_node = prev_node + + # We set these to None so that we don't get circular references, + # allowing us to be dropped without having to go via the GC. + self.prev_node = None + self.next_node = None + + def _refs_insert_after(self, node: "ListNode"): + """Internal method to insert the node after the given node.""" + + # This method should only be called when we're not already in the list. + assert self.prev_node is None + assert self.next_node is None + + # We expect the given node to be in the list and thus have valid + # prev/next refs. + assert node.next_node + assert node.prev_node + + prev_node = node + next_node = node.next_node + + self.prev_node = prev_node + self.next_node = next_node + + prev_node.next_node = self + next_node.prev_node = self + + def get_cache_entry(self) -> Optional[P]: + """Get the cache entry, returns None if this is the root node (i.e. + cache_entry is None) or if the entry has been dropped. + """ + return self.cache_entry diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 377904e72e..6578f3411e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -15,7 +15,7 @@ from unittest.mock import Mock -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache from tests import unittest @@ -260,3 +260,47 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): self.assertEquals(cache["key3"], [3]) self.assertEquals(cache["key4"], [4]) self.assertEquals(cache["key5"], [5, 6]) + + +class TimeEvictionTestCase(unittest.HomeserverTestCase): + """Test that time based eviction works correctly.""" + + def default_config(self): + config = super().default_config() + + config.setdefault("caches", {})["expiry_time"] = "30m" + + return config + + def test_evict(self): + setup_expire_lru_cache_entries(self.hs) + + cache = LruCache(5, clock=self.hs.get_clock()) + + # Check that we evict entries we haven't accessed for 30 minutes. + cache["key1"] = 1 + cache["key2"] = 2 + + self.reactor.advance(20 * 60) + + self.assertEqual(cache.get("key1"), 1) + + self.reactor.advance(20 * 60) + + # We have only touched `key1` in the last 30m, so we expect that to + # still be in the cache while `key2` should have been evicted. + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), None) + + # Check that re-adding an expired key works correctly. + cache["key2"] = 3 + self.assertEqual(cache.get("key2"), 3) + + self.reactor.advance(20 * 60) + + self.assertEqual(cache.get("key2"), 3) + + self.reactor.advance(20 * 60) + + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), 3) -- cgit 1.5.1 From bcb0962a7250d6c1430ad42f5ed234ffea8f2468 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:08:53 +0200 Subject: Fix deactivate a user if he does not have a profile (#10252) --- changelog.d/10252.bugfix | 1 + synapse/storage/databases/main/profile.py | 8 +-- tests/rest/admin/test_user.py | 86 ++++++++++++++++++++++++------- 3 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 changelog.d/10252.bugfix (limited to 'tests') diff --git a/changelog.d/10252.bugfix b/changelog.d/10252.bugfix new file mode 100644 index 0000000000..c8ddd14528 --- /dev/null +++ b/changelog.d/10252.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.26.0 where only users who have set profile information could be deactivated with erasure enabled. diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 9b4e95e134..ba7075caa5 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -73,20 +73,20 @@ class ProfileWorkerStore(SQLBaseStore): async def set_profile_displayname( self, user_localpart: str, new_displayname: Optional[str] ) -> None: - await self.db_pool.simple_update_one( + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, + values={"displayname": new_displayname}, desc="set_profile_displayname", ) async def set_profile_avatar_url( self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: - await self.db_pool.simple_update_one( + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a34d051734..4fccce34fd 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -939,7 +939,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -950,7 +950,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -960,7 +960,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): @@ -990,7 +990,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): @@ -1006,7 +1006,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): def test_deactivate_user_erase_true(self): """ - Test deactivating an user and set `erase` to `true` + Test deactivating a user and set `erase` to `true` """ # Get user @@ -1016,24 +1016,22 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) - # Deactivate user - body = json.dumps({"erase": True}) - + # Deactivate and erase user channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"erase": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1042,7 +1040,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1053,7 +1051,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): def test_deactivate_user_erase_false(self): """ - Test deactivating an user and set `erase` to `false` + Test deactivating a user and set `erase` to `false` """ # Get user @@ -1063,7 +1061,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1071,13 +1069,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("User1", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"erase": False}) - channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"erase": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1089,7 +1085,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1098,6 +1094,60 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", False) + def test_deactivate_user_erase_true_no_profile(self): + """ + Test deactivating a user and set `erase` to `true` + if user has no profile information (stored in the database table `profiles`). + """ + + # Users normally have an entry in `profiles`, but occasionally they are created without one. + # To test deactivation for users without a profile, we delete the profile information for our user. + self.get_success( + self.store.db_pool.simple_delete_one( + table="profiles", keyvalues={"user_id": "user"} + ) + ) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + # Deactivate and erase user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"erase": True}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) -- cgit 1.5.1 From 47e28b4031c7c5e2c87824c2b4873492b996d02e Mon Sep 17 00:00:00 2001 From: Dagfinn Ilmari Mannsåker Date: Tue, 6 Jul 2021 14:31:13 +0100 Subject: Ignore EDUs for rooms we're not in (#10317) --- changelog.d/10317.bugfix | 1 + synapse/handlers/receipts.py | 15 +++++++++++++++ synapse/handlers/typing.py | 14 ++++++++++++++ tests/handlers/test_typing.py | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+) create mode 100644 changelog.d/10317.bugfix (limited to 'tests') diff --git a/changelog.d/10317.bugfix b/changelog.d/10317.bugfix new file mode 100644 index 0000000000..826c269eff --- /dev/null +++ b/changelog.d/10317.bugfix @@ -0,0 +1 @@ +Fix purging rooms that other homeservers are still sending events for. Contributed by @ilmari. diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index f782d9db32..0059ad0f56 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -30,6 +30,8 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() + self.event_auth_handler = hs.get_event_auth_handler() + self.hs = hs # We only need to poke the federation sender explicitly if its on the @@ -59,6 +61,19 @@ class ReceiptsHandler(BaseHandler): """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] for room_id, room_values in content.items(): + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + is_in_room = await self.event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring receipt from %s as we're not in the room", + origin, + ) + continue + for receipt_type, users in room_values.items(): for user_id, user_values in users.items(): if get_domain_from_id(user_id) != origin: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e22393adc4..c0a8364755 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -208,6 +208,7 @@ class TypingWriterHandler(FollowerTypingHandler): self.auth = hs.get_auth() self.notifier = hs.get_notifier() + self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -326,6 +327,19 @@ class TypingWriterHandler(FollowerTypingHandler): room_id = content["room_id"] user_id = content["user_id"] + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + is_in_room = await self.event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring typing update from %s as we're not in the room", + origin, + ) + return + member = RoomMember(user_id=user_id, room_id=room_id) # Check that the string is a valid user id diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f58afbc244..fa3cff598e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -38,6 +38,9 @@ U_ONION = UserID.from_string("@onion:farm") # Test room id ROOM_ID = "a-room" +# Room we're not in +OTHER_ROOM_ID = "another-room" + def _expect_edu_transaction(edu_type, content, origin="test"): return { @@ -115,6 +118,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_auth().check_user_in_room = check_user_in_room + async def check_host_in_room(room_id, server_name): + return room_id == ROOM_ID + + hs.get_event_auth_handler().check_host_in_room = check_host_in_room + def get_joined_hosts_for_room(room_id): return {member.domain for member in self.room_members} @@ -244,6 +252,35 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + def test_started_typing_remote_recv_not_in_room(self): + self.room_members = [U_APPLE, U_ONION] + + self.assertEquals(self.event_source.get_current_key(), 0) + + channel = self.make_request( + "PUT", + "/_matrix/federation/v1/send/1000000", + _make_edu_transaction_json( + "m.typing", + content={ + "room_id": OTHER_ROOM_ID, + "user_id": U_ONION.to_string(), + "typing": True, + }, + ), + federation_auth_origin=b"farm", + ) + self.assertEqual(channel.code, 200) + + self.on_new_event.assert_not_called() + + self.assertEquals(self.event_source.get_current_key(), 0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0) + ) + self.assertEquals(events[0], []) + self.assertEquals(events[1], 0) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] -- cgit 1.5.1 From f6767abc054f3461cd9a70ba096fcf9a8e640edb Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 8 Jul 2021 10:57:13 -0500 Subject: Remove functionality associated with unused historical stats tables (#9721) Fixes #9602 --- changelog.d/9721.removal | 1 + docs/room_and_user_statistics.md | 50 +---- docs/sample_config.yaml | 5 - synapse/config/stats.py | 9 - synapse/handlers/stats.py | 27 --- synapse/storage/databases/main/purge_events.py | 1 - synapse/storage/databases/main/stats.py | 291 +------------------------ synapse/storage/schema/__init__.py | 6 +- tests/handlers/test_stats.py | 203 +---------------- tests/rest/admin/test_room.py | 1 - 10 files changed, 22 insertions(+), 572 deletions(-) create mode 100644 changelog.d/9721.removal (limited to 'tests') diff --git a/changelog.d/9721.removal b/changelog.d/9721.removal new file mode 100644 index 0000000000..da2ba48c84 --- /dev/null +++ b/changelog.d/9721.removal @@ -0,0 +1 @@ +Remove functionality associated with the unused `room_stats_historical` and `user_stats_historical` tables. Contributed by @xmunoz. diff --git a/docs/room_and_user_statistics.md b/docs/room_and_user_statistics.md index e1facb38d4..cc38c890bb 100644 --- a/docs/room_and_user_statistics.md +++ b/docs/room_and_user_statistics.md @@ -1,9 +1,9 @@ Room and User Statistics ======================== -Synapse maintains room and user statistics (as well as a cache of room state), -in various tables. These can be used for administrative purposes but are also -used when generating the public room directory. +Synapse maintains room and user statistics in various tables. These can be used +for administrative purposes but are also used when generating the public room +directory. # Synapse Developer Documentation @@ -15,48 +15,8 @@ used when generating the public room directory. * **subject**: Something we are tracking stats about – currently a room or user. * **current row**: An entry for a subject in the appropriate current statistics table. Each subject can have only one. -* **historical row**: An entry for a subject in the appropriate historical - statistics table. Each subject can have any number of these. ### Overview -Stats are maintained as time series. There are two kinds of column: - -* absolute columns – where the value is correct for the time given by `end_ts` - in the stats row. (Imagine a line graph for these values) - * They can also be thought of as 'gauges' in Prometheus, if you are familiar. -* per-slice columns – where the value corresponds to how many of the occurrences - occurred within the time slice given by `(end_ts − bucket_size)…end_ts` - or `start_ts…end_ts`. (Imagine a histogram for these values) - -Stats are maintained in two tables (for each type): current and historical. - -Current stats correspond to the present values. Each subject can only have one -entry. - -Historical stats correspond to values in the past. Subjects may have multiple -entries. - -## Concepts around the management of stats - -### Current rows - -Current rows contain the most up-to-date statistics for a room. -They only contain absolute columns - -### Historical rows - -Historical rows can always be considered to be valid for the time slice and -end time specified. - -* historical rows will not exist for every time slice – they will be omitted - if there were no changes. In this case, the following assumptions can be - made to interpolate/recreate missing rows: - - absolute fields have the same values as in the preceding row - - per-slice fields are zero (`0`) -* historical rows will not be retained forever – rows older than a configurable - time will be purged. - -#### Purge - -The purging of historical rows is not yet implemented. +Stats correspond to the present values. Current rows contain the most up-to-date +statistics for a room. Each subject can only have one entry. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 71463168e3..cbbe7d58d9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2652,11 +2652,6 @@ stats: # #enabled: false - # The size of each timeslice in the room_stats_historical and - # user_stats_historical tables, as a time period. Defaults to "1d". - # - #bucket_size: 1h - # Server Notices room configuration # diff --git a/synapse/config/stats.py b/synapse/config/stats.py index 78f61fe9da..6f253e00c0 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -38,13 +38,9 @@ class StatsConfig(Config): def read_config(self, config, **kwargs): self.stats_enabled = True - self.stats_bucket_size = 86400 * 1000 stats_config = config.get("stats", None) if stats_config: self.stats_enabled = stats_config.get("enabled", self.stats_enabled) - self.stats_bucket_size = self.parse_duration( - stats_config.get("bucket_size", "1d") - ) if not self.stats_enabled: logger.warning(ROOM_STATS_DISABLED_WARN) @@ -59,9 +55,4 @@ class StatsConfig(Config): # correctly. # #enabled: false - - # The size of each timeslice in the room_stats_historical and - # user_stats_historical tables, as a time period. Defaults to "1d". - # - #bucket_size: 1h """ diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 4e45d1da57..814d08efcb 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -45,7 +45,6 @@ class StatsHandler: self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.stats_bucket_size = hs.config.stats_bucket_size self.stats_enabled = hs.config.stats_enabled @@ -106,20 +105,6 @@ class StatsHandler: room_deltas = {} user_deltas = {} - # Then count deltas for total_events and total_event_bytes. - ( - room_count, - user_count, - ) = await self.store.get_changes_room_total_events_and_bytes( - self.pos, max_pos - ) - - for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, Counter()).update(fields) - - for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, Counter()).update(fields) - logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -181,12 +166,10 @@ class StatsHandler: event_content = {} # type: JsonDict - sender = None if event_id is not None: event = await self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} - sender = event.sender # All the values in this dict are deltas (RELATIVE changes) room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter()) @@ -244,12 +227,6 @@ class StatsHandler: room_stats_delta["joined_members"] += 1 elif membership == Membership.INVITE: room_stats_delta["invited_members"] += 1 - - if sender and self.is_mine_id(sender): - user_to_stats_deltas.setdefault(sender, Counter())[ - "invites_sent" - ] += 1 - elif membership == Membership.LEAVE: room_stats_delta["left_members"] += 1 elif membership == Membership.BAN: @@ -279,10 +256,6 @@ class StatsHandler: room_state["is_federatable"] = ( event_content.get("m.federate", True) is True ) - if sender and self.is_mine_id(sender): - user_to_stats_deltas.setdefault(sender, Counter())[ - "rooms_created" - ] += 1 elif typ == EventTypes.JoinRules: room_state["join_rules"] = event_content.get("join_rule") elif typ == EventTypes.RoomHistoryVisibility: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 7fb7780d0f..ec6b1eb5d4 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -392,7 +392,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 82a1833509..b10bee6daf 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -26,7 +26,6 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import StoreError from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state_deltas import StateDeltasStore -from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -49,14 +48,6 @@ ABSOLUTE_STATS_FIELDS = { "user": ("joined_rooms",), } -# these fields are per-timeslice and so should be reset to 0 upon a new slice -# You can draw these stats on a histogram. -# Example: number of events sent locally during a time slice -PER_SLICE_FIELDS = { - "room": ("total_events", "total_event_bytes"), - "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), -} - TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} # these are the tables (& ID columns) which contain our actual subjects @@ -106,7 +97,6 @@ class StatsStore(StateDeltasStore): self.server_name = hs.hostname self.clock = self.hs.get_clock() self.stats_enabled = hs.config.stats_enabled - self.stats_bucket_size = hs.config.stats_bucket_size self.stats_delta_processing_lock = DeferredLock() @@ -122,22 +112,6 @@ class StatsStore(StateDeltasStore): self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") self.db_pool.updates.register_noop_background_update("populate_stats_prepare") - def quantise_stats_time(self, ts): - """ - Quantises a timestamp to be a multiple of the bucket size. - - Args: - ts (int): the timestamp to quantise, in milliseconds since the Unix - Epoch - - Returns: - int: a timestamp which - - is divisible by the bucket size; - - is no later than `ts`; and - - is the largest such timestamp. - """ - return (ts // self.stats_bucket_size) * self.stats_bucket_size - async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. @@ -288,56 +262,6 @@ class StatsStore(StateDeltasStore): desc="update_room_state", ) - async def get_statistics_for_subject( - self, stats_type: str, stats_id: str, start: str, size: int = 100 - ) -> List[dict]: - """ - Get statistics for a given subject. - - Args: - stats_type: The type of subject - stats_id: The ID of the subject (e.g. room_id or user_id) - start: Pagination start. Number of entries, not timestamp. - size: How many entries to return. - - Returns: - A list of dicts, where the dict has the keys of - ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". - """ - return await self.db_pool.runInteraction( - "get_statistics_for_subject", - self._get_statistics_for_subject_txn, - stats_type, - stats_id, - start, - size, - ) - - def _get_statistics_for_subject_txn( - self, txn, stats_type, stats_id, start, size=100 - ): - """ - Transaction-bound version of L{get_statistics_for_subject}. - """ - - table, id_col = TYPE_TO_TABLE[stats_type] - selected_columns = list( - ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] - ) - - slice_list = self.db_pool.simple_select_list_paginate_txn( - txn, - table + "_historical", - "end_ts", - start, - size, - retcols=selected_columns + ["bucket_size", "end_ts"], - keyvalues={id_col: stats_id}, - order_direction="DESC", - ) - - return slice_list - @cached() async def get_earliest_token_for_stats( self, stats_type: str, id: str @@ -451,14 +375,10 @@ class StatsStore(StateDeltasStore): table, id_col = TYPE_TO_TABLE[stats_type] - quantised_ts = self.quantise_stats_time(int(ts)) - end_ts = quantised_ts + self.stats_bucket_size - # Lets be paranoid and check that all the given field names are known abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] - slice_field_names = PER_SLICE_FIELDS[stats_type] for field in chain(fields.keys(), absolute_field_overrides.keys()): - if field not in abs_field_names and field not in slice_field_names: + if field not in abs_field_names: # guard against potential SQL injection dodginess raise ValueError( "%s is not a recognised field" @@ -491,20 +411,6 @@ class StatsStore(StateDeltasStore): additive_relatives=deltas_of_absolute_fields, ) - per_slice_additive_relatives = { - key: fields.get(key, 0) for key in slice_field_names - } - self._upsert_copy_from_table_with_additive_relatives_txn( - txn=txn, - into_table=table + "_historical", - keyvalues={id_col: stats_id}, - extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, - extra_dst_keyvalues={"end_ts": end_ts}, - additive_relatives=per_slice_additive_relatives, - src_table=table + "_current", - copy_columns=abs_field_names, - ) - def _upsert_with_additive_relatives_txn( self, txn, table, keyvalues, absolutes, additive_relatives ): @@ -572,201 +478,6 @@ class StatsStore(StateDeltasStore): current_row.update(absolutes) self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) - def _upsert_copy_from_table_with_additive_relatives_txn( - self, - txn, - into_table, - keyvalues, - extra_dst_keyvalues, - extra_dst_insvalues, - additive_relatives, - src_table, - copy_columns, - ): - """Updates the historic stats table with latest updates. - - This involves copying "absolute" fields from the `_current` table, and - adding relative fields to any existing values. - - Args: - txn: Transaction - into_table (str): The destination table to UPSERT the row into - keyvalues (dict[str, any]): Row-identifying key values - extra_dst_keyvalues (dict[str, any]): Additional keyvalues - for `into_table`. - extra_dst_insvalues (dict[str, any]): Additional values to insert - on new row creation for `into_table`. - additive_relatives (dict[str, any]): Fields that will be added onto - if existing row present. (Must be disjoint from copy_columns.) - src_table (str): The source table to copy from - copy_columns (iterable[str]): The list of columns to copy - """ - if self.database_engine.can_native_upsert: - ins_columns = chain( - keyvalues, - copy_columns, - additive_relatives, - extra_dst_keyvalues, - extra_dst_insvalues, - ) - sel_exprs = chain( - keyvalues, - copy_columns, - ( - "?" - for _ in chain( - additive_relatives, extra_dst_keyvalues, extra_dst_insvalues - ) - ), - ) - keyvalues_where = ("%s = ?" % f for f in keyvalues) - - sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) - sets_ar = ( - "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) - for f in additive_relatives - ) - - sql = """ - INSERT INTO %(into_table)s (%(ins_columns)s) - SELECT %(sel_exprs)s - FROM %(src_table)s - WHERE %(keyvalues_where)s - ON CONFLICT (%(keyvalues)s) - DO UPDATE SET %(sets)s - """ % { - "into_table": into_table, - "ins_columns": ", ".join(ins_columns), - "sel_exprs": ", ".join(sel_exprs), - "keyvalues_where": " AND ".join(keyvalues_where), - "src_table": src_table, - "keyvalues": ", ".join( - chain(keyvalues.keys(), extra_dst_keyvalues.keys()) - ), - "sets": ", ".join(chain(sets_cc, sets_ar)), - } - - qargs = list( - chain( - additive_relatives.values(), - extra_dst_keyvalues.values(), - extra_dst_insvalues.values(), - keyvalues.values(), - ) - ) - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, into_table) - src_row = self.db_pool.simple_select_one_txn( - txn, src_table, keyvalues, copy_columns - ) - all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.db_pool.simple_select_one_txn( - txn, - into_table, - keyvalues=all_dest_keyvalues, - retcols=list(chain(additive_relatives.keys(), copy_columns)), - allow_none=True, - ) - - if dest_current_row is None: - merged_dict = { - **keyvalues, - **extra_dst_keyvalues, - **extra_dst_insvalues, - **src_row, - **additive_relatives, - } - self.db_pool.simple_insert_txn(txn, into_table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - src_row[key] = dest_current_row[key] + val - self.db_pool.simple_update_txn( - txn, into_table, all_dest_keyvalues, src_row - ) - - async def get_changes_room_total_events_and_bytes( - self, min_pos: int, max_pos: int - ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: - """Fetches the counts of events in the given range of stream IDs. - - Args: - min_pos - max_pos - - Returns: - Mapping of room ID to field changes. - """ - - return await self.db_pool.runInteraction( - "stats_incremental_total_events_and_bytes", - self.get_changes_room_total_events_and_bytes_txn, - min_pos, - max_pos, - ) - - def get_changes_room_total_events_and_bytes_txn( - self, txn, low_pos: int, high_pos: int - ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: - """Gets the total_events and total_event_bytes counts for rooms and - senders, in a range of stream_orderings (including backfilled events). - - Args: - txn - low_pos: Low stream ordering - high_pos: High stream ordering - - Returns: - The room and user deltas for total_events/total_event_bytes in the - format of `stats_id` -> fields - """ - - if low_pos >= high_pos: - # nothing to do here. - return {}, {} - - if isinstance(self.database_engine, PostgresEngine): - new_bytes_expression = "OCTET_LENGTH(json)" - else: - new_bytes_expression = "LENGTH(CAST(json AS BLOB))" - - sql = """ - SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.room_id - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - room_deltas = { - room_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for room_id, new_events, new_bytes in txn - } - - sql = """ - SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.sender - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - user_deltas = { - user_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for user_id, new_events, new_bytes in txn - if self.hs.is_mine_id(user_id) - } - - return room_deltas, user_deltas - async def _calculate_and_set_initial_state_for_room( self, room_id: str ) -> Tuple[dict, dict, int]: diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 0a53b73ccc..36340a652a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 60 +SCHEMA_VERSION = 61 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -21,6 +21,10 @@ older versions of Synapse). See `README.md `_ for more information on how this works. + +Changes in SCHEMA_VERSION = 61: + - The `user_stats_historical` and `room_stats_historical` tables are not written and + are not read (previously, they were written but not read). """ diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index c9d4fd9336..e4059acda3 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -88,16 +88,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): def _get_current_stats(self, stats_type, stat_id): table, id_col = stats.TYPE_TO_TABLE[stats_type] - cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( - stats.PER_SLICE_FIELDS[stats_type] - ) - - end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) + cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) return self.get_success( self.store.db_pool.simple_select_one( - table + "_historical", - {id_col: stat_id, end_ts: end_ts}, + table + "_current", + {id_col: stat_id}, cols, allow_none=True, ) @@ -156,115 +152,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") - def test_initial_earliest_token(self): - """ - Ingestion via notify_new_event will ignore tokens that the background - update have already processed. - """ - - self.reactor.advance(86401) - - self.hs.config.stats_enabled = False - self.handler.stats_enabled = False - - u1 = self.register_user("u1", "pass") - u1_token = self.login("u1", "pass") - - u2 = self.register_user("u2", "pass") - u2_token = self.login("u2", "pass") - - u3 = self.register_user("u3", "pass") - u3_token = self.login("u3", "pass") - - room_1 = self.helper.create_room_as(u1, tok=u1_token) - self.helper.send_state( - room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token - ) - - # Begin the ingestion by creating the temp tables. This will also store - # the position that the deltas should begin at, once they take over. - self.hs.config.stats_enabled = True - self.handler.stats_enabled = True - self.store.db_pool.updates._all_done = False - self.get_success( - self.store.db_pool.simple_update_one( - table="stats_incremental_position", - keyvalues={}, - updatevalues={"stream_id": 0}, - ) - ) - - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "populate_stats_prepare", "progress_json": "{}"}, - ) - ) - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - # Now, before the table is actually ingested, add some more events. - self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) - self.helper.join(room=room_1, user=u2, tok=u2_token) - - # orig_delta_processor = self.store. - - # Now do the initial ingestion. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_stats_cleanup", - "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", - }, - ) - ) - - self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - self.reactor.advance(86401) - - # Now add some more events, triggering ingestion. Because of the stream - # position being set to before the events sent in the middle, a simpler - # implementation would reprocess those events, and say there were four - # users, not three. - self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) - self.helper.join(room=room_1, user=u3, tok=u3_token) - - # self.handler.notify_new_event() - - # We need to let the delta processor advance… - self.reactor.advance(10 * 60) - - # Get the slices! There should be two -- day 1, and day 2. - r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) - - self.assertEqual(len(r), 2) - - # The oldest has 2 joined members - self.assertEqual(r[-1]["joined_members"], 2) - - # The newest has 3 - self.assertEqual(r[0]["joined_members"], 3) - def test_create_user(self): """ When we create a user, it should have statistics already ready. @@ -296,22 +183,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNotNone(r1stats) self.assertIsNotNone(r2stats) - # contains the default things you'd expect in a fresh room - self.assertEqual( - r1stats["total_events"], - EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM, - "Wrong number of total_events in new room's stats!" - " You may need to update this if more state events are added to" - " the room creation process.", - ) - self.assertEqual( - r2stats["total_events"], - EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, - "Wrong number of total_events in new room's stats!" - " You may need to update this if more state events are added to" - " the room creation process.", - ) - self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM ) @@ -327,24 +198,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(r2stats["invited_members"], 0) self.assertEqual(r2stats["banned_members"], 0) - def test_send_message_increments_total_events(self): - """ - When we send a message, it increments total_events. - """ - - self._perform_background_initial_update() - - u1 = self.register_user("u1", "pass") - u1token = self.login("u1", "pass") - r1 = self.helper.create_room_as(u1, tok=u1token) - r1stats_ante = self._get_current_stats("room", r1) - - self.helper.send(r1, "hiss", tok=u1token) - - r1stats_post = self._get_current_stats("room", r1) - - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) - def test_updating_profile_information_does_not_increase_joined_members_count(self): """ Check that the joined_members count does not increase when a user changes their @@ -378,7 +231,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_send_state_event_nonoverwriting(self): """ - When we send a non-overwriting state event, it increments total_events AND current_state_events + When we send a non-overwriting state event, it increments current_state_events """ self._perform_background_initial_update() @@ -399,44 +252,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, ) - def test_send_state_event_overwriting(self): - """ - When we send an overwriting state event, it increments total_events ONLY - """ - - self._perform_background_initial_update() - - u1 = self.register_user("u1", "pass") - u1token = self.login("u1", "pass") - r1 = self.helper.create_room_as(u1, tok=u1token) - - self.helper.send_state( - r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" - ) - - r1stats_ante = self._get_current_stats("room", r1) - - self.helper.send_state( - r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby" - ) - - r1stats_post = self._get_current_stats("room", r1) - - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) - self.assertEqual( - r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], - 0, - ) - def test_join_first_time(self): """ - When a user joins a room for the first time, total_events, current_state_events and + When a user joins a room for the first time, current_state_events and joined_members should increase by exactly 1. """ @@ -455,7 +278,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, @@ -466,7 +288,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_join_after_leave(self): """ - When a user joins a room after being previously left, total_events and + When a user joins a room after being previously left, joined_members should increase by exactly 1. current_state_events should not increase. left_members should decrease by exactly 1. @@ -490,7 +312,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -504,7 +325,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_invited(self): """ - When a user invites another user, current_state_events, total_events and + When a user invites another user, current_state_events and invited_members should increase by exactly 1. """ @@ -522,7 +343,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, @@ -533,7 +353,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_join_after_invite(self): """ - When a user joins a room after being invited, total_events and + When a user joins a room after being invited and joined_members should increase by exactly 1. current_state_events should not increase. invited_members should decrease by exactly 1. @@ -556,7 +376,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -570,7 +389,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_left(self): """ - When a user leaves a room after joining, total_events and + When a user leaves a room after joining and left_members should increase by exactly 1. current_state_events should not increase. joined_members should decrease by exactly 1. @@ -593,7 +412,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -607,7 +425,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_banned(self): """ - When a user is banned from a room after joining, total_events and + When a user is banned from a room after joining and left_members should increase by exactly 1. current_state_events should not increase. banned_members should decrease by exactly 1. @@ -630,7 +448,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ee071c2477..959d3cea77 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1753,7 +1753,6 @@ PURGE_TABLES = [ "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", -- cgit 1.5.1 From 19d0401c56a8f31441c65e62ffd688f615536d76 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 12 Jul 2021 11:21:04 -0400 Subject: Additional unit tests for spaces summary. (#10305) --- changelog.d/10305.misc | 1 + tests/handlers/test_space_summary.py | 204 ++++++++++++++++++++++++++++++++++- 2 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10305.misc (limited to 'tests') diff --git a/changelog.d/10305.misc b/changelog.d/10305.misc new file mode 100644 index 0000000000..8488d47f6f --- /dev/null +++ b/changelog.d/10305.misc @@ -0,0 +1 @@ +Additional unit tests for the spaces summary API. diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 9771d3fb3b..faed1f1a18 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,7 +14,7 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock -from synapse.api.constants import EventContentFields, RoomTypes +from synapse.api.constants import EventContentFields, JoinRules, RoomTypes from synapse.api.errors import AuthError from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin @@ -178,3 +178,205 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, [self.space]) self._assert_events(result, [(self.space, self.room)]) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. + subspace = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + subroom = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, subspace, token=self.token) + self._add_child(subspace, subroom, token=self.token) + # Also add the two rooms from the space into this subspace (causing loops). + self._add_child(subspace, self.room, token=self.token) + self._add_child(subspace, room2, self.token) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + + # The result should include each room a single time and each link. + self._assert_rooms(result, [self.space, self.room, subspace, subroom]) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, room2), + (self.space, subspace), + (subspace, subroom), + (subspace, self.room), + (subspace, room2), + ], + ) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + # Return some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + rooms = [ + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + { + "room_id": subroom, + "world_readable": True, + }, + ] + event_content = {"via": [fed_hostname]} + events = [ + { + "room_id": subspace, + "state_key": subroom, + "content": event_content, + }, + ] + return rooms, events + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms(result, [self.space, self.room, subspace, subroom]) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, subspace), + (subspace, subroom), + ], + ) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + # Note that these entries are brief, but should contain enough info. + rooms = [ + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [], + }, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [self.room], + }, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ] + + # Place each room in the sub-space. + event_content = {"via": [fed_hostname]} + events = [ + { + "room_id": subspace, + "state_key": room["room_id"], + "content": event_content, + } + for room in rooms + ] + + # Also include the subspace. + rooms.insert( + 0, + { + "room_id": subspace, + "world_readable": True, + }, + ) + return rooms, events + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms( + result, + [ + self.space, + self.room, + subspace, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, subspace), + (subspace, restricted_room), + (subspace, restricted_accessible_room), + (subspace, world_readable_room), + (subspace, joined_room), + ], + ) -- cgit 1.5.1 From 89cfc3dd9849b0580146151098ad039a7680c63f Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 12:43:15 +0200 Subject: [pyupgrade] `tests/` (#10347) --- changelog.d/10347.misc | 1 + tests/config/test_load.py | 4 ++-- tests/handlers/test_profile.py | 2 +- .../http/federation/test_matrix_federation_agent.py | 2 +- tests/http/test_fedclient.py | 8 +++----- tests/replication/_base.py | 6 +++--- tests/replication/test_multi_media_repo.py | 4 ++-- tests/replication/test_sharded_event_persister.py | 6 +++--- tests/rest/admin/test_admin.py | 6 ++---- tests/rest/admin/test_room.py | 20 ++++++++++---------- tests/rest/client/v1/test_rooms.py | 14 +++++++------- tests/rest/client/v2_alpha/test_relations.py | 2 +- tests/rest/client/v2_alpha/test_report_event.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_profile.py | 12 ++---------- tests/storage/test_purge.py | 2 +- tests/storage/test_room.py | 2 +- tests/test_types.py | 4 +--- tests/unittest.py | 2 +- 20 files changed, 45 insertions(+), 58 deletions(-) create mode 100644 changelog.d/10347.misc (limited to 'tests') diff --git a/changelog.d/10347.misc b/changelog.d/10347.misc new file mode 100644 index 0000000000..b2275a1350 --- /dev/null +++ b/changelog.d/10347.misc @@ -0,0 +1 @@ +Run `pyupgrade` on the codebase. \ No newline at end of file diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ebe2c05165..903c69127d 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) @@ -120,7 +120,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def generate_config_and_remove_lines_containing(self, needle): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: contents = f.readlines() contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index cdb41101b3..2928c4f48c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -103,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - (self.get_success(self.store.get_profile_displayname(self.frank.localpart))) + self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) def test_set_my_name_if_disabled(self): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index e45980316b..a37bce08c3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -273,7 +273,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode("ascii")) + request.write(b'{ "a": 1 }') request.finish() self.reactor.pump((0.1,)) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ed9a884d76..d9a8b077d3 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -102,7 +102,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode("ascii") + res_json = b'{ "a": 1 }' protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -339,10 +339,8 @@ class FederationClientTests(HomeserverTestCase): # Send it the HTTP response client.dataReceived( - ( - b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n" - ) + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" ) # Push by enough to time it out diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 624bd1b927..386ea70a25 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -550,12 +550,12 @@ class FakeRedisPubSubProtocol(Protocol): if obj is None: return "$-1\r\n" if isinstance(obj, str): - return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + return f"${len(obj)}\r\n{obj}\r\n" if isinstance(obj, int): - return ":{val}\r\n".format(val=obj) + return f":{obj}\r\n" if isinstance(obj, (list, tuple)): items = "".join(self.encode(a) for a in obj) - return "*{len}\r\n{items}".format(len=len(obj), items=items) + return f"*{len(obj)}\r\n{items}" raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 76e6644353..b42f1288eb 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -70,7 +70,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, FakeSite(resource), "GET", - "/{}/{}".format(target, media_id), + f"/{target}/{media_id}", shorthand=False, access_token=self.access_token, await_result=False, @@ -113,7 +113,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5eca5c165d..f3615af97e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -211,7 +211,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) @@ -241,7 +241,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(vector_clock_token), + f"/sync?since={vector_clock_token}", access_token=access_token, ) @@ -266,7 +266,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 2f7090e554..a7c6e595b9 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Create a new group channel = self.make_request( "POST", - "/create_group".encode("ascii"), + b"/create_group", access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request( - "GET", "/joined_groups".encode("ascii"), access_token=access_token - ) + channel = self.make_request("GET", b"/joined_groups", access_token=access_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 959d3cea77..17ec8bfd3b 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") def _assert_peek(self, room_id, expect_code): """Assert that the admin user can (or cannot) peek into the room.""" @@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") class RoomTestCase(unittest.HomeserverTestCase): @@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.public_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" def test_requester_is_no_admin(self): """ @@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Join user to room. - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e94566ffd7..3df070c936 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/join", content={"reason": reason}, access_token=self.second_tok, ) @@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) @@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/kick", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) @@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/ban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/unban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/invite", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 856aa8682f..2e2f94742e 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index 1ec6b05e5b..a76a6fef1e 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) resp = self.helper.send(self.room_id, tok=self.admin_user_tok) self.event_id = resp["event_id"] - self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" def test_reason_str_and_score_int(self): data = {"reason": "this makes me sad", "score": -100} diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 95e7075841..2d6b49692e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote("\u2603".encode("utf8")).encode("ascii") + filename = parse.quote("\u2603".encode()).encode("ascii") channel = self._req( b"inline; filename*=utf-8''" + filename + self.test_image.extension ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 41bef62ca8..43628ce44f 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -59,5 +59,5 @@ class DirectoryStoreTestCase(HomeserverTestCase): self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (self.get_success(self.store.get_association_from_room_alias(self.alias))) + self.get_success(self.store.get_association_from_room_alias(self.alias)) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 8a446da848..a1ba99ff14 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -45,11 +45,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_displayname(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) def test_avatar_url(self): @@ -76,9 +72,5 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_avatar_url(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) ) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 54c5b470c7..e5574063f1 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -75,7 +75,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format(token.topological + 1, token.stream + 1) + event = f"t{token.topological + 1}-{token.stream + 1}" # Purge everything before this topological token f = self.get_failure( diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 70257bf210..31ce7f6252 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -49,7 +49,7 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room_unknown_room(self): - self.assertIsNone((self.get_success(self.store.get_room("!uknown:test")))) + self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) def test_get_room_with_stats(self): self.assertDictContainsSubset( diff --git a/tests/test_types.py b/tests/test_types.py index d7881021d3..0d0c00d97a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -103,6 +103,4 @@ class MapUsernameTestCase(unittest.TestCase): def testNonAscii(self): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") - self.assertEqual( - map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" - ) + self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") diff --git a/tests/unittest.py b/tests/unittest.py index 74db7c08f1..907b94b10a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -140,7 +140,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))("Assert error for '.{}':".format(key)) from e + raise (type(e))(f"Assert error for '.{key}':") from e def assert_dict(self, required, actual): """Does a partial assert of a dict. -- cgit 1.5.1 From 93729719b8451493e1df9930feb9f02f14ea5cef Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 12:52:58 +0200 Subject: Use inline type hints in `tests/` (#10350) This PR is tantamount to running: python3.8 -m com2ann -v 6 tests/ (com2ann requires python 3.8 to run) --- changelog.d/10350.misc | 1 + tests/events/test_presence_router.py | 6 +++--- tests/module_api/test_api.py | 16 ++++++++-------- tests/replication/_base.py | 12 ++++++------ tests/replication/tcp/streams/test_events.py | 14 +++++++------- tests/replication/tcp/streams/test_receipts.py | 4 ++-- tests/replication/tcp/streams/test_typing.py | 4 ++-- tests/replication/test_multi_media_repo.py | 2 +- tests/rest/client/test_third_party_rules.py | 4 ++-- tests/rest/client/v1/test_login.py | 14 ++++++-------- tests/server.py | 8 +++++--- tests/storage/test_background_update.py | 4 +--- tests/storage/test_id_generators.py | 6 +++--- tests/test_state.py | 2 +- tests/test_utils/html_parsers.py | 6 +++--- tests/unittest.py | 2 +- tests/util/caches/test_descriptors.py | 2 +- tests/util/test_itertools.py | 18 +++++++++--------- 18 files changed, 62 insertions(+), 63 deletions(-) create mode 100644 changelog.d/10350.misc (limited to 'tests') diff --git a/changelog.d/10350.misc b/changelog.d/10350.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10350.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. \ No newline at end of file diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 875b0d0a11..c4ad33194d 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -152,7 +152,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_one_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "boop") @@ -274,7 +274,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence_updates, _ = sync_presence(self, self.other_user_id) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "I'm online!") @@ -320,7 +320,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 2c68b9a13c..81d9e2f484 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -100,9 +100,9 @@ class ModuleApiTestCase(HomeserverTestCase): "content": content, "sender": user_id, } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.message") self.assertEqual(event.room_id, room_id) @@ -136,9 +136,9 @@ class ModuleApiTestCase(HomeserverTestCase): "sender": user_id, "state_key": "", } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.power_levels") self.assertEqual(event.room_id, room_id) @@ -281,7 +281,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] @@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 386ea70a25..e9fd991718 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() - self.server = server_factory.buildProtocol( + self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol( None - ) # type: ServerReplicationStreamProtocol + ) # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" @@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): fetching updates for given stream. """ - path = request.path # type: bytes # type: ignore + path: bytes = request.path # type: ignore self.assertRegex( path, br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" @@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. """ - servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + servlets: List[Callable[[HomeServer, JsonResource], None]] = [] def setUp(self): super().setUp() @@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler): super().__init__(hs) # list of received (stream_name, token, row) tuples - self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] + self.received_rdata_rows: List[Tuple[str, int, Any]] = [] async def on_rdata(self, stream_name, instance_name, token, rows): await super().on_rdata(stream_name, instance_name, token, rows) @@ -484,7 +484,7 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" - transport = None # type: Optional[FakeTransport] + transport: Optional[FakeTransport] = None def __init__(self, server: FakeRedisPubSubServer): self._server = server diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f51fa0a79e..666008425a 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) events = [ self._inject_state_event(sender=OTHER_USER) @@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual(row.data.event_id, pl_event.event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for stream_name, _, row in received_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) - events = [] # type: List[EventBase] + events: List[EventBase] = [] for user in user_ids: events.extend( self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) @@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual(row.data.event_id, pl_events[i].event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for _ in range(STATES_PER_USER + 1): stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 7f5d932f0b..38e292c1ab 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) @@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room2:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index ecd360c2d0..3ff5afc6e5 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) @@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index b42f1288eb..ffa425328f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request logger = logging.getLogger(__name__) -test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory] +test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e1fe72fc5d..c5e1c5458b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -233,11 +233,11 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "content": content, "sender": self.user_id, } - event = self.get_success( + event: EventBase = self.get_success( current_rules_module().module_api.create_and_send_event_into_room( event_dict ) - ) # type: EventBase + ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 605b952316..7eba69642a 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # stick the flows results in a dict by type - flow_results = {} # type: Dict[str, Any] + flow_results: Dict[str, Any] = {} for f in channel.json_body["flows"]: flow_type = f["type"] self.assertNotIn( @@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.close() # there should be a link for each href - returned_idps = [] # type: List[str] + returned_idps: List[str] = [] for link in p.links: path, query = link.split("?", 1) self.assertEqual(path, "pick_idp") @@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") assert cookie_headers - cookies = {} # type: Dict[str, str] + cookies: Dict[str, str] = {} for h in cookie_headers: key, value = h.split(";")[0].split("=", maxsplit=1) cookies[key] = value @@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode( - payload, secret, self.jwt_algorithm - ) # type: Union[str, bytes] + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) if isinstance(result, bytes): return result.decode("ascii") return result @@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") if isinstance(result, bytes): return result.decode("ascii") return result @@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie - cookies = {} # type: Dict[str,str] + cookies: Dict[str, str] = {} channel.extract_cookies(cookies) self.assertIn("username_mapping_session", cookies) session_id = cookies["username_mapping_session"] diff --git a/tests/server.py b/tests/server.py index f32d8dc375..6fddd3b305 100644 --- a/tests/server.py +++ b/tests/server.py @@ -52,7 +52,7 @@ class FakeChannel: _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) _ip = attr.ib(type=str, default="127.0.0.1") - _producer = None # type: Optional[Union[IPullProducer, IPushProducer]] + _producer: Optional[Union[IPullProducer, IPushProducer]] = None @property def json_body(self): @@ -316,8 +316,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] - lookups = self.lookups = {} # type: Dict[str, str] - self._thread_callbacks = deque() # type: Deque[Callable[[], None]] + self.lookups: Dict[str, str] = {} + self._thread_callbacks: Deque[Callable[[], None]] = deque() + + lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 069db0edc4..0da42b5ac5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -7,9 +7,7 @@ from tests import unittest class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = ( - self.hs.get_datastore().db_pool.updates - ) # type: BackgroundUpdater + self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 792b1c44c1..7486078284 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/test_state.py b/tests/test_state.py index 62f7095873..780eba823c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): self.store.register_events(graph.walk()) - context_store = {} # type: dict[str, EventContext] + context_store: dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index 1fbb38f4be..e878af5f12 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -23,13 +23,13 @@ class TestHtmlParser(HTMLParser): super().__init__() # a list of links found in the doc - self.links = [] # type: List[str] + self.links: List[str] = [] # the values of any hidden s: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] + self.hiddens: Dict[str, Optional[str]] = {} # the values of any radio buttons: map from name to list of values - self.radios = {} # type: Dict[str, List[Optional[str]]] + self.radios: Dict[str, List[Optional[str]]] = {} def handle_starttag( self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] diff --git a/tests/unittest.py b/tests/unittest.py index 907b94b10a..c6d9064423 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase): if not isinstance(deferred, Deferred): return d - results = [] # type: list + results: list = [] deferred.addBoth(results.append) self.pump(by=by) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 0277998cbe..39947a166b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -174,7 +174,7 @@ class DescriptorTestCase(unittest.TestCase): return self.result obj = Cls() - callbacks = set() # type: Set[str] + callbacks: Set[str] = set() # set off an asynchronous request obj.result = origin_d = defer.Deferred() diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index e712eb42ea..3c0ddd4f18 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase): ) def test_empty_input(self): - parts = chunk_seq([], 5) # type: Iterable[Sequence] + parts: Iterable[Sequence] = chunk_seq([], 5) self.assertEqual( list(parts), @@ -56,13 +56,13 @@ class SortTopologically(TestCase): def test_empty(self): "Test that an empty graph works correctly" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} self.assertEqual(list(sorted_topologically([], graph)), []) def test_handle_empty_graph(self): "Test that a graph where a node doesn't have an entry is treated as empty" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -70,7 +70,7 @@ class SortTopologically(TestCase): def test_disconnected(self): "Test that a graph with no edges work" - graph = {1: [], 2: []} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: []} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -78,19 +78,19 @@ class SortTopologically(TestCase): def test_linear(self): "Test that a simple `4 -> 3 -> 2 -> 1` graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_subset(self): "Test that only sorting a subset of the graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) def test_fork(self): "Test that a forked graph works" - graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should # always get the same one. @@ -98,12 +98,12 @@ class SortTopologically(TestCase): def test_duplicates(self): "Test that a graph with duplicate edges work" - graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_multiple_paths(self): "Test that a graph with multiple paths between two nodes work" - graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From 2d16e69b4bf09b5274a8fa15c8ca4719db8366c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 13 Jul 2021 08:59:27 -0400 Subject: Show all joinable rooms in the spaces summary. (#10298) Previously only world-readable rooms were shown. This means that rooms which are public, knockable, or invite-only with a pending invitation, are included in a space summary. It also applies the same logic to the experimental room version from MSC3083 -- if a user has access to the proper allowed rooms then it is shown in the spaces summary. This change is made per MSC3173 allowing stripped state of a room to be shown to any potential room joiner. --- changelog.d/10298.feature | 1 + changelog.d/10305.feature | 1 + changelog.d/10305.misc | 1 - synapse/handlers/space_summary.py | 68 +++++++--- synapse/storage/databases/main/roommember.py | 13 +- tests/handlers/test_space_summary.py | 191 ++++++++++++++++++++++++--- 6 files changed, 237 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10298.feature create mode 100644 changelog.d/10305.feature delete mode 100644 changelog.d/10305.misc (limited to 'tests') diff --git a/changelog.d/10298.feature b/changelog.d/10298.feature new file mode 100644 index 0000000000..7059db5075 --- /dev/null +++ b/changelog.d/10298.feature @@ -0,0 +1 @@ +The spaces summary API now returns any joinable rooms, not only rooms which are world-readable. diff --git a/changelog.d/10305.feature b/changelog.d/10305.feature new file mode 100644 index 0000000000..7059db5075 --- /dev/null +++ b/changelog.d/10305.feature @@ -0,0 +1 @@ +The spaces summary API now returns any joinable rooms, not only rooms which are world-readable. diff --git a/changelog.d/10305.misc b/changelog.d/10305.misc deleted file mode 100644 index 8488d47f6f..0000000000 --- a/changelog.d/10305.misc +++ /dev/null @@ -1 +0,0 @@ -Additional unit tests for the spaces summary API. diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index b585057ec3..366e6211e5 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -24,6 +24,7 @@ from synapse.api.constants import ( EventContentFields, EventTypes, HistoryVisibility, + JoinRules, Membership, RoomTypes, ) @@ -150,14 +151,21 @@ class SpaceSummaryHandler: # The room should only be included in the summary if: # a. the user is in the room; # b. the room is world readable; or - # c. the user is in a space that has been granted access to - # the room. + # c. the user could join the room, e.g. the join rules + # are set to public or the user is in a space that + # has been granted access to the room. # # Note that we know the user is not in the root room (which is # why the remote call was made in the first place), but the user # could be in one of the children rooms and we just didn't know # about the link. - include_room = room.get("world_readable") is True + + # The API doesn't return the room version so assume that a + # join rule of knock is valid. + include_room = ( + room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + or room.get("world_readable") is True + ) # Check if the user is a member of any of the allowed spaces # from the response. @@ -420,9 +428,8 @@ class SpaceSummaryHandler: It should be included if: - * The requester is joined or invited to the room. - * The requester can join without an invite (per MSC3083). - * The origin server has any user that is joined or invited to the room. + * The requester is joined or can join the room (per MSC3173). + * The origin server has any user that is joined or can join the room. * The history visibility is set to world readable. Args: @@ -441,13 +448,39 @@ class SpaceSummaryHandler: # If there's no state for the room, it isn't known. if not state_ids: + # The user might have a pending invite for the room. + if requester and await self._store.get_invite_for_local_user_in_room( + requester, room_id + ): + return True + logger.info("room %s is unknown, omitting from summary", room_id) return False room_version = await self._store.get_room_version(room_id) - # if we have an authenticated requesting user, first check if they are able to view - # stripped state in the room. + # Include the room if it has join rules of public or knock. + join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) + if join_rules_event_id: + join_rules_event = await self._store.get_event(join_rules_event_id) + join_rule = join_rules_event.content.get("join_rule") + if join_rule == JoinRules.PUBLIC or ( + room_version.msc2403_knocking and join_rule == JoinRules.KNOCK + ): + return True + + # Include the room if it is peekable. + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + # Otherwise we need to check information specific to the user or server. + + # If we have an authenticated requesting user, check if they are a member + # of the room (or can join the room). if requester: member_event_id = state_ids.get((EventTypes.Member, requester), None) @@ -470,9 +503,11 @@ class SpaceSummaryHandler: return True # If this is a request over federation, check if the host is in the room or - # is in one of the spaces specified via the join rules. + # has a user who could join the room. elif origin: - if await self._event_auth_handler.check_host_in_room(room_id, origin): + if await self._event_auth_handler.check_host_in_room( + room_id, origin + ) or await self._store.is_host_invited(room_id, origin): return True # Alternately, if the host has a user in any of the spaces specified @@ -490,18 +525,10 @@ class SpaceSummaryHandler: ): return True - # otherwise, check if the room is peekable - hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None) - if hist_vis_event_id: - hist_vis_ev = await self._store.get_event(hist_vis_event_id) - hist_vis = hist_vis_ev.content.get("history_visibility") - if hist_vis == HistoryVisibility.WORLD_READABLE: - return True - logger.info( - "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary", + "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", room_id, - requester, + requester or origin, ) return False @@ -535,6 +562,7 @@ class SpaceSummaryHandler: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], + "join_rules": stats["join_rules"], "world_readable": ( stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2796354a1f..4d82c4c26d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -703,13 +703,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: + return await self._check_host_room_membership(room_id, host, Membership.JOIN) + + @cached(max_entries=10000) + async def is_host_invited(self, room_id: str, host: str) -> bool: + return await self._check_host_room_membership(room_id, host, Membership.INVITE) + + async def _check_host_room_membership( + self, room_id: str, host: str, membership: str + ) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ SELECT state_key FROM current_state_events AS c INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' + WHERE m.membership = ? AND type = 'm.room.member' AND c.room_id = ? AND state_key LIKE ? @@ -722,7 +731,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): like_clause = "%:" + host rows = await self.db_pool.execute( - "is_host_joined", None, sql, room_id, like_clause + "is_host_joined", None, sql, membership, room_id, like_clause ) if not rows: diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index faed1f1a18..3f73ad7f94 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,8 +14,18 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock -from synapse.api.constants import EventContentFields, JoinRules, RoomTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -117,7 +127,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): """Add a child room to a space.""" self.helper.send_state( space_id, - event_type="m.space.child", + event_type=EventTypes.SpaceChild, body={"via": [self.hs.hostname]}, tok=token, state_key=room_id, @@ -155,29 +165,129 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # The user cannot see the space. self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - # Joining the room causes it to be visible. - self.helper.join(self.space, user2, tok=token2) + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - - # The result should only have the space, but includes the link to the room. - self._assert_rooms(result, [self.space]) + self._assert_rooms(result, [self.space, self.room]) self._assert_events(result, [(self.space, self.room)]) - def test_world_readable(self): - """A world-readable room is visible to everyone.""" + # Make it not world-readable again and confirm it results in an error. self.helper.send_state( self.space, - event_type="m.room.history_visibility", - body={"history_visibility": "world_readable"}, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, tok=self.token, ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + + # Join the space and results should be returned. + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, [self.space, self.room]) + self._assert_events(result, [(self.space, self.room)]) + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, + tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, + ) + self._add_child(self.space, room_id, self.token) + return room_id + + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") - # The space should be visible, as well as the link to the room. + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.MSC3083_RESTRICTED, + room_version=RoomVersions.MSC3083.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.MSC3083_RESTRICTED, + room_version=RoomVersions.MSC3083.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space]) - self._assert_events(result, [(self.space, self.room)]) + + self._assert_rooms( + result, + [ + self.space, + self.room, + public_room, + knock_room, + invited_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, public_room), + (self.space, knock_room), + (self.space, not_invited_room), + (self.space, invited_room), + (self.space, restricted_room), + (self.space, restricted_accessible_room), + (self.space, world_readable_room), + (self.space, joined_room), + ], + ) def test_complex_space(self): """ @@ -186,7 +296,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Create an inaccessible room. user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") - room2 = self.helper.create_room_as(user2, tok=token2) + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) # This is a bit odd as "user" is adding a room they don't know about, but # it works for the tests. self._add_child(self.space, room2, self.token) @@ -292,16 +402,60 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): subspace = "#subspace:" + fed_hostname # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname restricted_room = "#restricted:" + fed_hostname restricted_accessible_room = "#restricted_accessible:" + fed_hostname world_readable_room = "#world_readable:" + fed_hostname joined_room = self.helper.create_room_as(self.user, tok=self.token) + # Poke an invite over federation into the database. + fed_handler = self.hs.get_federation_handler() + event = make_event_from_dict( + { + "room_id": invited_room, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": "@remote:" + fed_hostname, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): # Note that these entries are brief, but should contain enough info. rooms = [ + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, { "room_id": restricted_room, "world_readable": False, @@ -364,6 +518,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.space, self.room, subspace, + public_room, + knock_room, + invited_room, restricted_accessible_room, world_readable_room, joined_room, @@ -374,6 +531,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): [ (self.space, self.room), (self.space, subspace), + (subspace, public_room), + (subspace, knock_room), + (subspace, not_invited_room), + (subspace, invited_room), (subspace, restricted_room), (subspace, restricted_accessible_room), (subspace, world_readable_room), -- cgit 1.5.1 From eb3beb8f12a5ee93e19eacf0f03c6bcde18999fe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Jul 2021 09:13:40 -0400 Subject: Add type hints and comments to event auth code. (#10393) --- changelog.d/10393.misc | 1 + mypy.ini | 1 + synapse/event_auth.py | 3 +++ tests/test_event_auth.py | 23 +++++++++++++---------- 4 files changed, 18 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10393.misc (limited to 'tests') diff --git a/changelog.d/10393.misc b/changelog.d/10393.misc new file mode 100644 index 0000000000..e80f16d607 --- /dev/null +++ b/changelog.d/10393.misc @@ -0,0 +1 @@ +Add type hints and comments to event auth code. diff --git a/mypy.ini b/mypy.ini index 72ce932d73..8717ae738e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,6 +83,7 @@ files = synapse/util/stringutils.py, synapse/visibility.py, tests/replication, + tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, tests/rest/client/v1/test_login.py, diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 89bcf81515..a3df6cfcc1 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -48,6 +48,9 @@ def check( room_version_obj: the version of the room event: the event being checked. auth_events: the existing room state. + do_sig_check: True if it should be verified that the sending server + signed the event. + do_size_check: True if the size of the event fields should be verified. Raises: AuthError if the checks fail diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 88888319cc..f73306ecc4 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,12 +13,13 @@ # limitations under the License. import unittest +from typing import Optional from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.types import get_domain_from_id +from synapse.events import EventBase, make_event_from_dict +from synapse.types import JsonDict, get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -432,7 +433,7 @@ class EventAuthTestCase(unittest.TestCase): TEST_ROOM_ID = "!test:room" -def _create_event(user_id): +def _create_event(user_id: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -444,7 +445,9 @@ def _create_event(user_id): ) -def _member_event(user_id, membership, sender=None): +def _member_event( + user_id: str, membership: str, sender: Optional[str] = None +) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -458,11 +461,11 @@ def _member_event(user_id, membership, sender=None): ) -def _join_event(user_id): +def _join_event(user_id: str) -> EventBase: return _member_event(user_id, "join") -def _power_levels_event(sender, content): +def _power_levels_event(sender: str, content: JsonDict) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -475,7 +478,7 @@ def _power_levels_event(sender, content): ) -def _alias_event(sender, **kwargs): +def _alias_event(sender: str, **kwargs) -> EventBase: data = { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -488,7 +491,7 @@ def _alias_event(sender, **kwargs): return make_event_from_dict(data) -def _random_state_event(sender): +def _random_state_event(sender: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -501,7 +504,7 @@ def _random_state_event(sender): ) -def _join_rules_event(sender, join_rule): +def _join_rules_event(sender: str, join_rule: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -519,7 +522,7 @@ def _join_rules_event(sender, join_rule): event_count = 0 -def _get_event_id(): +def _get_event_id() -> str: global event_count c = event_count event_count += 1 -- cgit 1.5.1 From c7603af1d06d65932c420ae76002b6ed94dbf23c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 15 Jul 2021 11:37:08 +0200 Subject: Allow providing credentials to `http_proxy` (#10360) --- changelog.d/10360.feature | 1 + synapse/http/proxyagent.py | 12 +++++++- tests/http/test_proxyagent.py | 65 ++++++++++++++++++++++++++++++++++--------- 3 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 changelog.d/10360.feature (limited to 'tests') diff --git a/changelog.d/10360.feature b/changelog.d/10360.feature new file mode 100644 index 0000000000..904221cb6d --- /dev/null +++ b/changelog.d/10360.feature @@ -0,0 +1 @@ +Allow providing credentials to `http_proxy`. \ No newline at end of file diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7dfae8b786..7a6a1717de 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -117,7 +117,8 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - # Parse credentials from https proxy connection string if present + # Parse credentials from http and https proxy connection string if present + self.http_proxy_creds, http_proxy = parse_username_password(http_proxy) self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) self.http_proxy_endpoint = _http_proxy_endpoint( @@ -189,6 +190,15 @@ class ProxyAgent(_AgentBase): and self.http_proxy_endpoint and not should_skip_proxy ): + # Determine whether we need to set Proxy-Authorization headers + if self.http_proxy_creds: + # Set a Proxy-Authorization header + if headers is None: + headers = Headers() + headers.addRawHeader( + b"Proxy-Authorization", + self.http_proxy_creds.as_proxy_authorization_value(), + ) # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: pool_key = ("http-proxy", self.http_proxy_endpoint) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index fefc8099c9..437113929a 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -205,6 +205,41 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) def test_http_request_via_proxy(self): + """ + Tests that requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, + ) + def test_http_request_via_proxy_with_auth(self): + """ + Tests that authenticated requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials="bob:pinkponies") + + @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) + def test_https_request_via_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + + def _do_http_request_via_proxy( + self, + auth_credentials: Optional[str] = None, + ): + """ + Tests that requests can be made through a proxy. + """ agent = ProxyAgent(self.reactor, use_proxy=True) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -229,6 +264,23 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(b"bob:pinkponies") + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"http://test.com") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) @@ -241,19 +293,6 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): - """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials=None) - - @patch.dict( - os.environ, - {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, - ) - def test_https_request_via_proxy_with_auth(self): - """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") - def _do_https_request_via_proxy( self, auth_credentials: Optional[str] = None, -- cgit 1.5.1 From ac5c221208ceb499cf8e9305b03efe1765ba48f6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 15 Jul 2021 11:52:56 +0100 Subject: Stagger send presence to remotes (#10398) This is to help with performance, where trying to connect to thousands of hosts at once can consume a lot of CPU (due to TLS etc). Co-authored-by: Brendan Abolivier --- changelog.d/10398.misc | 1 + synapse/federation/sender/__init__.py | 96 +++++++++++++++++++++- synapse/federation/sender/per_destination_queue.py | 16 +++- tests/events/test_presence_router.py | 8 ++ 4 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10398.misc (limited to 'tests') diff --git a/changelog.d/10398.misc b/changelog.d/10398.misc new file mode 100644 index 0000000000..326e54655a --- /dev/null +++ b/changelog.d/10398.misc @@ -0,0 +1 @@ +Stagger sending of presence update to remote servers, reducing CPU spikes caused by starting many connections to remote servers at once. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 0960f033bc..d980e0d986 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,9 +14,12 @@ import abc import logging +from collections import OrderedDict from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple +import attr from prometheus_client import Counter +from typing_extensions import Literal from twisted.internet import defer @@ -33,8 +36,12 @@ from synapse.metrics import ( event_processing_loop_room_count, events_processed_counter, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.util import Clock from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -137,6 +144,84 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() +@attr.s +class _PresenceQueue: + """A queue of destinations that need to be woken up due to new presence + updates. + + Staggers waking up of per destination queues to ensure that we don't attempt + to start TLS connections with many hosts all at once, leading to pinned CPU. + """ + + # The maximum duration in seconds between queuing up a destination and it + # being woken up. + _MAX_TIME_IN_QUEUE = 30.0 + + # The maximum duration in seconds between waking up consecutive destination + # queues. + _MAX_DELAY = 0.1 + + sender: "FederationSender" = attr.ib() + clock: Clock = attr.ib() + queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict) + processing: bool = attr.ib(default=False) + + def add_to_queue(self, destination: str) -> None: + """Add a destination to the queue to be woken up.""" + + self.queue[destination] = None + + if not self.processing: + self._handle() + + @wrap_as_background_process("_PresenceQueue.handle") + async def _handle(self) -> None: + """Background process to drain the queue.""" + + if not self.queue: + return + + assert not self.processing + self.processing = True + + try: + # We start with a delay that should drain the queue quickly enough that + # we process all destinations in the queue in _MAX_TIME_IN_QUEUE + # seconds. + # + # We also add an upper bound to the delay, to gracefully handle the + # case where the queue only has a few entries in it. + current_sleep_seconds = min( + self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue) + ) + + while self.queue: + destination, _ = self.queue.popitem(last=False) + + queue = self.sender._get_per_destination_queue(destination) + + if not queue._new_data_to_send: + # The per destination queue has already been woken up. + continue + + queue.attempt_new_transaction() + + await self.clock.sleep(current_sleep_seconds) + + if not self.queue: + break + + # More destinations may have been added to the queue, so we may + # need to reduce the delay to ensure everything gets processed + # within _MAX_TIME_IN_QUEUE seconds. + current_sleep_seconds = min( + current_sleep_seconds, self._MAX_TIME_IN_QUEUE / len(self.queue) + ) + + finally: + self.processing = False + + class FederationSender(AbstractFederationSender): def __init__(self, hs: "HomeServer"): self.hs = hs @@ -208,6 +293,8 @@ class FederationSender(AbstractFederationSender): self._external_cache = hs.get_external_cache() + self._presence_queue = _PresenceQueue(self, self.clock) + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -517,7 +604,12 @@ class FederationSender(AbstractFederationSender): self._instance_name, destination ): continue - self._get_per_destination_queue(destination).send_presence(states) + + self._get_per_destination_queue(destination).send_presence( + states, start_loop=False + ) + + self._presence_queue.add_to_queue(destination) def build_and_send_edu( self, diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d06a3aff19..c11d1f6d31 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -171,14 +171,24 @@ class PerDestinationQueue: self.attempt_new_transaction() - def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if necessary. + def send_presence( + self, states: Iterable[UserPresenceState], start_loop: bool = True + ) -> None: + """Add presence updates to the queue. + + Args: + states: Presence updates to send + start_loop: Whether to start the transmission loop if not already + running. Args: states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) - self.attempt_new_transaction() + self._new_data_to_send = True + + if start_loop: + self.attempt_new_transaction() def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index c4ad33194d..3f41e99950 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -285,6 +285,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) self.assertEqual(len(presence_updates), 3) + # We stagger sending of presence, so we need to wait a bit for them to + # get sent out. + self.reactor.advance(60) + # Test that sending to a remote user works remote_user_id = "@far_away_person:island" @@ -301,6 +305,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): self.module_api.send_local_online_presence_to([remote_user_id]) ) + # We stagger sending of presence, so we need to wait a bit for them to + # get sent out. + self.reactor.advance(60) + # Check that the expected presence updates were sent # We explicitly compare using sets as we expect that calling # module_api.send_local_online_presence_to will create a presence -- cgit 1.5.1 From 6a6006825067827b533b9c2b35c5a1d6a796e27c Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Thu, 15 Jul 2021 13:51:27 +0100 Subject: Add tests to characterise the current behaviour of R30 phone-home metrics (#10315) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/10315.misc | 1 + tests/app/test_phone_stats_home.py | 153 +++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 changelog.d/10315.misc create mode 100644 tests/app/test_phone_stats_home.py (limited to 'tests') diff --git a/changelog.d/10315.misc b/changelog.d/10315.misc new file mode 100644 index 0000000000..2c78644e20 --- /dev/null +++ b/changelog.d/10315.misc @@ -0,0 +1 @@ +Add tests to characterise the current behaviour of R30 phone-home metrics. diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py new file mode 100644 index 0000000000..2da6ba4dde --- /dev/null +++ b/tests/app/test_phone_stats_home.py @@ -0,0 +1,153 @@ +import synapse +from synapse.rest.client.v1 import login, room + +from tests import unittest +from tests.unittest import HomeserverTestCase + +ONE_DAY_IN_SECONDS = 86400 + + +class PhoneHomeTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + # Override the retention time for the user_ips table because otherwise it + # gets pruned too aggressively for our R30 test. + @unittest.override_config({"user_ips_max_age": "365d"}) + def test_r30_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30 metric + to consider a user 'retained'. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the R30 results do not count that user. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Advance 30 days (+ 1 second, because strict inequality causes issues if we are + # bang on 30 days later). + self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait some time for _update_client_ips_batch to get + # called and update the user_ips table. + self.reactor.advance(2 * 60 * 60) + + # *Now* the user is counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance 29 days. The user has now not posted for 29 days. + self.reactor.advance(29 * ONE_DAY_IN_SECONDS) + + # The user is still counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance another day. The user has now not posted for 30 days. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # The user is now no longer counted in R30. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + def test_r30_minimum_usage_using_default_config(self): + """ + Tests the minimum amount of interaction necessary for the R30 metric + to consider a user 'retained'. + + N.B. This test does not override the `user_ips_max_age` config setting, + which defaults to 28 days. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the R30 results do not count that user. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Advance 30 days (+ 1 second, because strict inequality causes issues if we are + # bang on 30 days later). + self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait some time for _update_client_ips_batch to get + # called and update the user_ips table. + self.reactor.advance(2 * 60 * 60) + + # *Now* the user is counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance 27 days. The user has now not posted for 27 days. + self.reactor.advance(27 * ONE_DAY_IN_SECONDS) + + # The user is still counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance another day. The user has now not posted for 28 days. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # The user is now no longer counted in R30. + # (This is because the user_ips table has been pruned, which by default + # only preserves the last 28 days of entries.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + def test_r30_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30 statistic, even if they post every day + during that time! + """ + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the user does not contribute to R30 yet. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "I'm still here", tok=access_token) + + # Notice that the user *still* does not contribute to R30! + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "Still here!", tok=access_token) + + # *Now* the user appears in R30. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) -- cgit 1.5.1 From 36dc15412de9fc1bb2ba955c8b6f2da20d2ca20f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 16 Jul 2021 18:11:53 +0200 Subject: Add a module type for account validity (#9884) This adds an API for third-party plugin modules to implement account validity, so they can provide this feature instead of Synapse. The module implementing the current behaviour for this feature can be found at https://github.com/matrix-org/synapse-email-account-validity. To allow for a smooth transition between the current feature and the new module, hooks have been added to the existing account validity endpoints to allow their behaviours to be overridden by a module. --- changelog.d/9884.feature | 1 + docs/modules.md | 47 ++++- docs/sample_config.yaml | 85 --------- synapse/api/auth.py | 17 +- synapse/config/account_validity.py | 102 ++--------- synapse/handlers/account_validity.py | 128 ++++++++++++- synapse/handlers/register.py | 5 + synapse/module_api/__init__.py | 219 +++++++++++++++++++++-- synapse/module_api/errors.py | 6 +- synapse/push/pusherpool.py | 24 +-- synapse/rest/admin/users.py | 24 ++- synapse/rest/client/v2_alpha/account_validity.py | 7 +- tests/test_state.py | 1 + 13 files changed, 438 insertions(+), 228 deletions(-) create mode 100644 changelog.d/9884.feature (limited to 'tests') diff --git a/changelog.d/9884.feature b/changelog.d/9884.feature new file mode 100644 index 0000000000..525fd2f93c --- /dev/null +++ b/changelog.d/9884.feature @@ -0,0 +1 @@ +Add a module type for the account validity feature. diff --git a/docs/modules.md b/docs/modules.md index bec1c06d15..c4cb7018f7 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following API method: ```python -def ModuleApi.register_web_resource(path: str, resource: IResource) +def ModuleApi.register_web_resource(path: str, resource: IResource) -> None ``` The path is the full absolute path to register the resource at. For example, if you @@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c and is under no obligation to implement all callbacks from the categories it registers callbacks for. +Modules can register callbacks using one of the module API's `register_[...]_callbacks` +methods. The callback functions are passed to these methods as keyword arguments, with +the callback name as the argument name and the function as its value. This is demonstrated +in the example below. A `register_[...]_callbacks` method exists for each module type +documented in this section. + #### Spam checker callbacks -To register one of the callbacks described in this section, a module needs to use the -module API's `register_spam_checker_callbacks` method. The callback functions are passed -to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the -argument name and the function as its value. This is demonstrated in the example below. +Spam checker callbacks allow module developers to implement spam mitigation actions for +Synapse instances. Spam checker callbacks can be registered using the module API's +`register_spam_checker_callbacks` method. The available spam checker callbacks are: @@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (i.e. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). ```python async def user_may_create_room(user: str) -> bool @@ -188,6 +193,36 @@ async def check_media_file_for_spam( Called when storing a local or remote file. The module must return a boolean indicating whether the given file can be stored in the homeserver's media store. +#### Account validity callbacks + +Account validity callbacks allow module developers to add extra steps to verify the +validity on an account, i.e. see if a user can be granted access to their account on the +Synapse instance. Account validity callbacks can be registered using the module API's +`register_account_validity_callbacks` method. + +The available account validity callbacks are: + +```python +async def is_user_expired(user: str) -> Optional[bool] +``` + +Called when processing any authenticated request (except for logout requests). The module +can return a `bool` to indicate whether the user has expired and should be locked out of +their account, or `None` if the module wasn't able to figure it out. The user is +represented by their Matrix user ID (e.g. `@alice:example.com`). + +If the module returns `True`, the current request will be denied with the error code +`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't +invalidate the user's access token. + +```python +async def on_user_registration(user: str) -> None +``` + +Called after successfully registering a user, in case the module needs to perform extra +operations to keep track of them. (e.g. add them to a database table). The user is +represented by their Matrix user ID. + ### Porting an existing module that uses the old interface In order to port a module that uses Synapse's old module interface, its author needs to: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a45732a246..f4845a5841 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1310,91 +1310,6 @@ account_threepid_delegates: #auto_join_rooms_for_guests: false -## Account Validity ## - -# Optional account validity configuration. This allows for accounts to be denied -# any request after a given period. -# -# Once this feature is enabled, Synapse will look for registered users without an -# expiration date at startup and will add one to every account it found using the -# current settings at that time. -# This means that, if a validity period is set, and Synapse is restarted (it will -# then derive an expiration date from the current validity period), and some time -# after that the validity period changes and Synapse is restarted, the users' -# expiration dates won't be updated unless their account is manually renewed. This -# date will be randomly selected within a range [now + period - d ; now + period], -# where d is equal to 10% of the validity period. -# -account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - - ## Metrics ### # Enable collection and rendering of performance metrics diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 8916e6fa2f..05699714ee 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -62,6 +62,7 @@ class Auth: self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() + self._account_validity_handler = hs.get_account_validity_handler() self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" @@ -69,9 +70,6 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users @@ -187,12 +185,17 @@ class Auth: shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. - if self._account_validity_enabled and not allow_expired: - if await self.store.is_account_expired( - user_info.user_id, self.clock.time_msec() + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + user_info.user_id ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired raise AuthError( - 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, ) device_id = user_info.device_id diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 957de7f3a6..6be4eafe55 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -18,6 +18,21 @@ class AccountValidityConfig(Config): section = "account_validity" def read_config(self, config, **kwargs): + """Parses the old account validity config. The config format looks like this: + + account_validity: + enabled: true + period: 6w + renew_at: 1w + renew_email_subject: "Renew your %(app)s account" + template_dir: "res/templates" + account_renewed_html_path: "account_renewed.html" + invalid_token_html_path: "invalid_token.html" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ account_validity_config = config.get("account_validity") or {} self.account_validity_enabled = account_validity_config.get("enabled", False) self.account_validity_renew_by_email_enabled = ( @@ -75,90 +90,3 @@ class AccountValidityConfig(Config): ], account_validity_template_dir, ) - - def generate_config_section(self, **kwargs): - return """\ - ## Account Validity ## - - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - """ diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d752cf34f0..078accd634 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -15,9 +15,11 @@ import email.mime.multipart import email.utils import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from synapse.api.errors import StoreError, SynapseError +from twisted.web.http import Request + +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils @@ -27,6 +29,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Types for callbacks to be registered via the module api +IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] +ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] +# Temporary hooks to allow for a transition from `/_matrix/client` endpoints +# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`. +ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] +ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]] +ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable] + class AccountValidityHandler: def __init__(self, hs: "HomeServer"): @@ -70,6 +81,99 @@ class AccountValidityHandler: if hs.config.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] + self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] + self._on_legacy_send_mail_callback: Optional[ + ON_LEGACY_SEND_MAIL_CALLBACK + ] = None + self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None + + # The legacy admin requests callback isn't a protected attribute because we need + # to access it from the admin servlet, which is outside of this handler. + self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None + + def register_account_validity_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ): + """Register callbacks from module for each hook.""" + if is_user_expired is not None: + self._is_user_expired_callbacks.append(is_user_expired) + + if on_user_registration is not None: + self._on_user_registration_callbacks.append(on_user_registration) + + # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and + # an admin one). As part of moving the feature into a module, we need to change + # the path from /_matrix/client/unstable/account_validity/... to + # /_synapse/client/account_validity, because: + # + # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix + # * the way we register servlets means that modules can't register resources + # under /_matrix/client + # + # We need to allow for a transition period between the old and new endpoints + # in order to allow for clients to update (and for emails to be processed). + # + # Once the email-account-validity module is loaded, it will take control of account + # validity by moving the rows from our `account_validity` table into its own table. + # + # Therefore, we need to allow modules (in practice just the one implementing the + # email-based account validity) to temporarily hook into the legacy endpoints so we + # can route the traffic coming into the old endpoints into the module, which is + # why we have the following three temporary hooks. + if on_legacy_send_mail is not None: + if self._on_legacy_send_mail_callback is not None: + raise RuntimeError("Tried to register on_legacy_send_mail twice") + + self._on_legacy_send_mail_callback = on_legacy_send_mail + + if on_legacy_renew is not None: + if self._on_legacy_renew_callback is not None: + raise RuntimeError("Tried to register on_legacy_renew twice") + + self._on_legacy_renew_callback = on_legacy_renew + + if on_legacy_admin_request is not None: + if self.on_legacy_admin_request_callback is not None: + raise RuntimeError("Tried to register on_legacy_admin_request twice") + + self.on_legacy_admin_request_callback = on_legacy_admin_request + + async def is_user_expired(self, user_id: str) -> bool: + """Checks if a user has expired against third-party modules. + + Args: + user_id: The user to check the expiry of. + + Returns: + Whether the user has expired. + """ + for callback in self._is_user_expired_callbacks: + expired = await callback(user_id) + if expired is not None: + return expired + + if self._account_validity_enabled: + # If no module could determine whether the user has expired and the legacy + # configuration is enabled, fall back to it. + return await self.store.is_account_expired(user_id, self.clock.time_msec()) + + return False + + async def on_user_registration(self, user_id: str): + """Tell third-party modules about a user's registration. + + Args: + user_id: The ID of the newly registered user. + """ + for callback in self._on_user_registration_callbacks: + await callback(user_id) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time @@ -95,6 +199,17 @@ class AccountValidityHandler: Raises: SynapseError if the user is not set to renew. """ + # If a module supports sending a renewal email from here, do that, otherwise do + # the legacy dance. + if self._on_legacy_send_mail_callback is not None: + await self._on_legacy_send_mail_callback(user_id) + return + + if not self._account_validity_renew_by_email_enabled: + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) # If this user isn't set to be expired, raise an error. @@ -209,6 +324,10 @@ class AccountValidityHandler: token is considered stale. A token is stale if the 'token_used_ts_ms' db column is non-null. + This method exists to support handling the legacy account validity /renew + endpoint. If a module implements the on_legacy_renew callback, then this process + is delegated to the module instead. + Args: renewal_token: Token sent with the renewal request. Returns: @@ -218,6 +337,11 @@ class AccountValidityHandler: * An int representing the user's expiry timestamp as milliseconds since the epoch, or 0 if the token was invalid. """ + # If a module supports triggering a renew from here, do that, otherwise do the + # legacy dance. + if self._on_legacy_renew_callback is not None: + return await self._on_legacy_renew_callback(renewal_token) + try: ( user_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 26ef016179..056fe5e89f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -77,6 +77,7 @@ class RegistrationHandler(BaseHandler): self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() + self._account_validity_handler = hs.get_account_validity_handler() self._server_notices_mxid = hs.config.server_notices_mxid self._server_name = hs.hostname @@ -700,6 +701,10 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + # Only call the account validity module(s) on the main process, to avoid + # repeating e.g. database writes on all of the workers. + await self._account_validity_handler.on_user_registration(user_id) + async def register_device( self, user_id: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 308f045700..f3c78089b7 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -12,18 +12,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import email.utils import logging -from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +import jinja2 from twisted.internet import defer from twisted.web.resource import IResource from synapse.events import EventBase from synapse.http.client import SimpleHttpClient +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util import Clock +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi are loaded into Synapse. """ -__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"] +__all__ = [ + "errors", + "make_deferred_yieldable", + "parse_json_object_from_request", + "respond_with_html", + "run_in_background", + "cached", + "UserID", + "DatabasePool", + "LoggingTransaction", + "DirectServeHtmlResource", + "DirectServeJsonResource", + "ModuleApi", +] logger = logging.getLogger(__name__) @@ -52,12 +89,27 @@ class ModuleApi: self._server_name = hs.hostname self._presence_stream = hs.get_event_sources().sources["presence"] self._state = hs.get_state_handler() + self._clock = hs.get_clock() # type: Clock + self._send_email_handler = hs.get_send_email_handler() + + try: + app_name = self._hs.config.email_app_name + + self._from_string = self._hs.config.email_notif_from % {"app": app_name} + except (KeyError, TypeError): + # If substitution failed (which can happen if the string contains + # placeholders other than just "app", or if the type of the placeholder is + # not a string), fall back to the bare strings. + self._from_string = self._hs.config.email_notif_from + + self._raw_from = email.utils.parseaddr(self._from_string)[1] # We expose these as properties below in order to attach a helpful docstring. self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) self._spam_checker = hs.get_spam_checker() + self._account_validity_handler = hs.get_account_validity_handler() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -67,6 +119,11 @@ class ModuleApi: """Registers callbacks for spam checking capabilities.""" return self._spam_checker.register_callbacks + @property + def register_account_validity_callbacks(self): + """Registers callbacks for account validity capabilities.""" + return self._account_validity_handler.register_account_validity_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. @@ -101,22 +158,56 @@ class ModuleApi: """ return self._public_room_list_manager - def get_user_by_req(self, req, allow_guest=False): + @property + def public_baseurl(self) -> str: + """The configured public base URL for this homeserver.""" + return self._hs.config.public_baseurl + + @property + def email_app_name(self) -> str: + """The application name configured in the homeserver's configuration.""" + return self._hs.config.email.email_app_name + + async def get_user_by_req( + self, + req: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + ) -> Requester: """Check the access_token provided for a request Args: - req (twisted.web.server.Request): Incoming HTTP request - allow_guest (bool): True if guest users should be allowed. If this + req: Incoming HTTP request + allow_guest: True if guest users should be allowed. If this is False, and the access token is for a guest user, an AuthError will be thrown + allow_expired: True if expired users should be allowed. If this + is False, and the access token is for an expired user, an + AuthError will be thrown + Returns: - twisted.internet.defer.Deferred[synapse.types.Requester]: - the requester for this request + The requester for this request + Raises: - synapse.api.errors.AuthError: if no user by that token exists, + InvalidClientCredentialsError: if no user by that token exists, or the token is invalid. """ - return self._auth.get_user_by_req(req, allow_guest) + return await self._auth.get_user_by_req( + req, + allow_guest, + allow_expired=allow_expired, + ) + + async def is_user_admin(self, user_id: str) -> bool: + """Checks if a user is a server admin. + + Args: + user_id: The Matrix ID of the user to check. + + Returns: + True if the user is a server admin, False otherwise. + """ + return await self._store.is_server_admin(UserID.from_string(user_id)) def get_qualified_user_id(self, username): """Qualify a user id, if necessary @@ -134,6 +225,32 @@ class ModuleApi: return username return UserID(username, self._hs.hostname).to_string() + async def get_profile_for_user(self, localpart: str) -> ProfileInfo: + """Look up the profile info for the user with the given localpart. + + Args: + localpart: The localpart to look up profile information for. + + Returns: + The profile information (i.e. display name and avatar URL). + """ + return await self._store.get_profileinfo(localpart) + + async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: + """Look up the threepids (email addresses and phone numbers) associated with the + given Matrix user ID. + + Args: + user_id: The Matrix user ID to look up threepids for. + + Returns: + A list of threepids, each threepid being represented by a dictionary + containing a "medium" key which value is "email" for email addresses and + "msisdn" for phone numbers, and an "address" key which value is the + threepid's address. + """ + return await self._store.user_get_threepids(user_id) + def check_user_exists(self, user_id): """Check if user exists. @@ -464,6 +581,88 @@ class ModuleApi: presence_events, destination ) + def looping_background_call( + self, + f: Callable, + msec: float, + *args, + desc: Optional[str] = None, + **kwargs, + ): + """Wraps a function as a background process and calls it repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Args: + f: The function to call repeatedly. f can be either synchronous or + asynchronous, and must follow Synapse's logcontext rules. + More info about logcontexts is available at + https://matrix-org.github.io/synapse/latest/log_contexts.html + msec: How long to wait between calls in milliseconds. + *args: Positional arguments to pass to function. + desc: The background task's description. Default to the function's name. + **kwargs: Key arguments to pass to function. + """ + if desc is None: + desc = f.__name__ + + if self._hs.config.run_background_tasks: + self._clock.looping_call( + run_as_background_process, + msec, + desc, + f, + *args, + **kwargs, + ) + else: + logger.warning( + "Not running looping call %s as the configuration forbids it", + f, + ) + + async def send_mail( + self, + recipient: str, + subject: str, + html: str, + text: str, + ): + """Send an email on behalf of the homeserver. + + Args: + recipient: The email address for the recipient. + subject: The email's subject. + html: The email's HTML content. + text: The email's text content. + """ + await self._send_email_handler.send_email( + email_address=recipient, + subject=subject, + app_name=self.email_app_name, + html=html, + text=text, + ) + + def read_templates( + self, + filenames: List[str], + custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Read and load the content of the template files at the given location. + By default, Synapse will look for these templates in its configured template + directory, but another directory to search in can be provided. + + Args: + filenames: The name of the template files to look for. + custom_template_directory: An additional directory to look for the files in. + + Returns: + A list containing the loaded templates, with the orders matching the one of + the filenames parameter. + """ + return self._hs.config.read_templates(filenames, custom_template_directory) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 02bbb0be39..98ea911a81 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -14,5 +14,9 @@ """Exception types which are exposed as part of the stable module API""" -from synapse.api.errors import RedirectException, SynapseError # noqa: F401 +from synapse.api.errors import ( # noqa: F401 + InvalidClientCredentialsError, + RedirectException, + SynapseError, +) from synapse.config._base import ConfigError # noqa: F401 diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2519ad76db..85621f33ef 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,10 +62,6 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) - # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() @@ -89,6 +85,8 @@ class PusherPool: # map from user id to app_id:pushkey to pusher self.pushers: Dict[str, Dict[str, Pusher]] = {} + self._account_validity_handler = hs.get_account_validity_handler() + def start(self) -> None: """Starts the pushers off in a background process.""" if not self._should_start_pushers: @@ -238,12 +236,9 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): @@ -268,12 +263,9 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 7d75564758..06e6ccee42 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - body = parse_json_object_from_request(request) + if self.account_activity_handler.on_legacy_admin_request_callback: + expiration_ts = await ( + self.account_activity_handler.on_legacy_admin_request_callback(request) + ) + else: + body = parse_json_object_from_request(request) - if "user_id" not in body: - raise SynapseError(400, "Missing property 'user_id' in the request body") + if "user_id" not in body: + raise SynapseError( + 400, + "Missing property 'user_id' in the request body", + ) - expiration_ts = await self.account_activity_handler.renew_account_for_user( - body["user_id"], - body.get("expiration_ts"), - not body.get("enable_renewal_emails", True), - ) + expiration_ts = await self.account_activity_handler.renew_account_for_user( + body["user_id"], + body.get("expiration_ts"), + not body.get("enable_renewal_emails", True), + ) res = {"expiration_ts": expiration_ts} return 200, res diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2d1ad3d3fb..3ebe401861 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -14,7 +14,7 @@ import logging -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet @@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet): ) async def on_POST(self, request): - if not self.account_validity_renew_by_email_enabled: - raise AuthError( - 403, "Account renewal via email is disabled on this server." - ) - requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() await self.account_activity_handler.send_renewal_email_to_user(user_id) diff --git a/tests/test_state.py b/tests/test_state.py index 780eba823c..e5488df1ac 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase): "get_state_handler", "get_clock", "get_state_resolution_handler", + "get_account_validity_handler", "hostname", ] ) -- cgit 1.5.1 From 4e340412c020f685cb402a735b983f6e332e206b Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 19 Jul 2021 16:11:34 +0100 Subject: Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric (#10332) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/10332.feature | 1 + synapse/app/phone_stats_home.py | 4 + synapse/storage/databases/main/metrics.py | 129 ++++++++++++++++ tests/app/test_phone_stats_home.py | 242 ++++++++++++++++++++++++++++++ tests/rest/client/v1/utils.py | 30 +++- tests/unittest.py | 15 +- 6 files changed, 416 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10332.feature (limited to 'tests') diff --git a/changelog.d/10332.feature b/changelog.d/10332.feature new file mode 100644 index 0000000000..091947ff22 --- /dev/null +++ b/changelog.d/10332.feature @@ -0,0 +1 @@ +Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 8f86cecb76..7904c246df 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -107,6 +107,10 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): for name, count in r30_results.items(): stats["r30_users_" + name] = count + r30v2_results = await hs.get_datastore().count_r30_users() + for name, count in r30v2_results.items(): + stats["r30v2_users_" + name] = count + stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index e3a544d9b2..dc0bbc56ac 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -316,6 +316,135 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + async def count_r30v2_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as users that: + - Appear more than once in the past 60 days + - Have more than 30 days between the most and least recent appearances that + occurred in the past 60 days. + + (This is the second version of this metric, hence R30'v2') + + Returns: + A mapping from client type to the number of 30-day retained users for that client. + + The dict keys are: + - "all" (a combined number of users across any and all clients) + - "android" (Element Android) + - "ios" (Element iOS) + - "electron" (Element Desktop) + - "web" (any web application -- it's not possible to distinguish Element Web here) + """ + + def _count_r30v2_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs + one_day_from_now_in_secs = now + 86400 + + # This is the 'per-platform' count. + sql = """ + SELECT + client_type, + count(client_type) + FROM + ( + SELECT + user_id, + CASE + WHEN + LOWER(user_agent) LIKE '%%riot%%' OR + LOWER(user_agent) LIKE '%%element%%' + THEN CASE + WHEN + LOWER(user_agent) LIKE '%%electron%%' + THEN 'electron' + WHEN + LOWER(user_agent) LIKE '%%android%%' + THEN 'android' + WHEN + LOWER(user_agent) LIKE '%%ios%%' + THEN 'ios' + ELSE 'unknown' + END + WHEN + LOWER(user_agent) LIKE '%%mozilla%%' OR + LOWER(user_agent) LIKE '%%gecko%%' + THEN 'web' + ELSE 'unknown' + END as client_type + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id, + client_type + HAVING + max(timestamp) - min(timestamp) > ? + ) AS temp + GROUP BY + client_type + ; + """ + + # We initialise all the client types to zero, so we get an explicit + # zero if they don't appear in the query results + results = {"ios": 0, "android": 0, "web": 0, "electron": 0} + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + + for row in txn: + if row[0] == "unknown": + continue + results[row[0]] = row[1] + + # This is the 'all users' count. + sql = """ + SELECT COUNT(*) FROM ( + SELECT + 1 + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id + HAVING + max(timestamp) - min(timestamp) > ? + ) AS r30_users + """ + + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + row = txn.fetchone() + if row is None: + results["all"] = 0 + else: + results["all"] = row[0] + + return results + + return await self.db_pool.runInteraction( + "count_r30v2_users", _count_r30v2_users + ) + def _get_start_of_day(self): """ Returns millisecond unixtime for start of UTC day. diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 2da6ba4dde..5527e278db 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,9 +1,11 @@ import synapse +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.rest.client.v1 import login, room from tests import unittest from tests.unittest import HomeserverTestCase +FIVE_MINUTES_IN_SECONDS = 300 ONE_DAY_IN_SECONDS = 86400 @@ -151,3 +153,243 @@ class PhoneHomeTestCase(HomeserverTestCase): # *Now* the user appears in R30. r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + +class PhoneHomeR30V2TestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def _advance_to(self, desired_time_secs: float): + now = self.hs.get_clock().time() + assert now < desired_time_secs + self.reactor.advance(desired_time_secs - now) + + def make_homeserver(self, reactor, clock): + hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) + + # We don't want our tests to actually report statistics, so check + # that it's not enabled + assert not hs.config.report_stats + + # This starts the needed data collection that we rely on to calculate + # R30v2 metrics. + start_phone_stats_home(hs) + return hs + + def test_r30v2_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30v2 metric + to consider a user 'retained'. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + first_post_at = self.hs.get_clock().time() + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the R30 results do not count that user. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance 31 days. + # (R30v2 includes users with **more** than 30 days between the two visits, + # and user_daily_visits records the timestamp as the start of the day.) + self.reactor.advance(31 * ONE_DAY_IN_SECONDS) + # Also advance 5 minutes to let another user_daily_visits update occur + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait a few minutes for the user_daily_visits table to + # be updated by a background process. + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user is counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance to JUST under 60 days after the user's first post + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5) + + # Check the user is still counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance into the next day. The user's first activity is now more than 60 days old. + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5) + + # Check the user is now no longer counted in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30v2 statistic, even if they post every day + during that time! + """ + + # set a custom user-agent to impersonate Element/Android. + headers = ( + ( + "User-Agent", + "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the user does not contribute to R30 yet. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS) + self.helper.send( + room_id, "I'm still here", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Notice that the user *still* does not contribute to R30! + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # advance yet another day with more activity + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send( + room_id, "Still here!", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user appears in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_returning_dormant_users_not_counted(self): + """ + Tests that dormant users (users inactive for a long time) do not + contribute to R30v2 when they return for just a single day. + This is a key difference between R30 and R30v2. + """ + + # set a custom user-agent to impersonate Element/iOS. + headers = ( + ( + "User-Agent", + "Riot/1.4 (iPhone; iOS 13; Scale/4.00)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # the user goes inactive for 2 months + self.reactor.advance(60 * ONE_DAY_IN_SECONDS) + + # the user returns for one day, perhaps just to check out a new feature + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check that the user does not contribute to R30v2, even though it's been + # more than 30 days since registration. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Check that this is a situation where old R30 differs: + # old R30 DOES count this as 'retained'. + r30_results = self.get_success(store.count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "ios": 1}) + + # Now we want to check that the user will still be able to appear in + # R30v2 as long as the user performs some other activity between + # 30 and 60 days later. + self.reactor.advance(32 * ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # (give time for tables to update) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Check the user now satisfies the requirements to appear in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 59.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS) + + # Check the user still appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 60.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # Check the user no longer appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 69798e95c3..fc2d35596e 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -19,7 +19,7 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Mapping, MutableMapping, Optional +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from unittest.mock import patch import attr @@ -53,6 +53,9 @@ class RestHelper: tok: str = None, expect_code: int = 200, extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> str: """ Create a room. @@ -87,6 +90,7 @@ class RestHelper: "POST", path, json.dumps(content).encode("utf8"), + custom_headers=custom_headers, ) assert channel.result["code"] == b"%d" % expect_code, channel.result @@ -175,14 +179,30 @@ class RestHelper: self.auth_user_id = temp_id - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): if body is None: body = "body_text_here" content = {"msgtype": "m.text", "body": body} return self.send_event( - room_id, "m.room.message", content, txn_id, tok, expect_code + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, ) def send_event( @@ -193,6 +213,9 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -207,6 +230,7 @@ class RestHelper: "PUT", path, json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, ) assert ( diff --git a/tests/unittest.py b/tests/unittest.py index c6d9064423..3eec9c4d5b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase): user_id = channel.json_body["user_id"] return user_id - def login(self, username, password, device_id=None): + def login( + self, + username, + password, + device_id=None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): """ Log in a user, and get an access token. Requires the Login API be registered. @@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase): body["device_id"] = device_id channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") + "POST", + "/_matrix/client/r0/login", + json.dumps(body).encode("utf8"), + custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1