From 6c02cca95f8136010062b6af0fa36a2906a96a6b Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 1 Jul 2021 11:26:24 +0200 Subject: Add SSO `external_ids` to Query User Account admin API (#10261) Related to #10251 --- tests/rest/admin/test_user.py | 224 ++++++++++++++++++++++++++---------------- 1 file changed, 140 insertions(+), 84 deletions(-) (limited to 'tests') diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index d599a4c984..a34d051734 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1150,7 +1150,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1160,7 +1160,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): @@ -1177,6 +1177,58 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) + def test_get_user(self): + """ + Test a simple get of a user. + """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) + self._check_fields(channel.json_body) + + def test_get_user_with_sso(self): + """ + Test get a user with SSO details. + """ + self.get_success( + self.store.record_user_external_id( + "auth_provider1", "external_id1", self.other_user + ) + ) + self.get_success( + self.store.record_user_external_id( + "auth_provider2", "external_id2", self.other_user + ) + ) + + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][1]["external_id"] + ) + self.assertEqual( + "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + ) + self._check_fields(channel.json_body) + def test_create_server_admin(self): """ Check that a new admin user is created successfully. @@ -1184,30 +1236,29 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user (server admin) - body = json.dumps( - { - "password": "abc123", - "admin": True, - "displayname": "Bob's name", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": "mxc://fibble/wibble", - } - ) + body = { + "password": "abc123", + "admin": True, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1216,7 +1267,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1225,6 +1276,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["is_guest"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) def test_create_user(self): """ @@ -1233,30 +1285,29 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "admin": False, - "displayname": "Bob's name", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": "mxc://fibble/wibble", - } - ) + body = { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1265,7 +1316,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1275,6 +1326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1311,16 +1363,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123", "admin": False}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "admin": False}, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1350,17 +1400,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123", "admin": False}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "admin": False}, ) # Admin user is not blocked by mau anymore - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1382,21 +1430,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - ) + body = { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1426,21 +1472,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - ) + body = { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1457,16 +1501,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Change password - body = json.dumps({"password": "hahaha"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "hahaha"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self._check_fields(channel.json_body) def test_set_displayname(self): """ @@ -1474,16 +1517,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Modify user - body = json.dumps({"displayname": "foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "foobar"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1494,7 +1535,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1504,18 +1545,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Delete old and add new threepid to user - body = json.dumps( - {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} - ) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1527,7 +1564,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1552,7 +1589,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1567,7 +1604,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1583,7 +1620,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1610,7 +1647,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -1626,7 +1663,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -1650,7 +1687,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -1659,7 +1696,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -1681,7 +1718,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -1691,7 +1728,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1713,7 +1750,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -1723,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1742,7 +1779,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -1753,7 +1790,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -1772,7 +1809,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -1783,7 +1820,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -1796,7 +1833,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -1805,7 +1842,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -1830,7 +1867,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -1838,6 +1875,25 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIsNone(self.get_success(d)) self._is_erased(user_id, True) + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + + Args: + content: Content dictionary to check + """ + self.assertIn("displayname", content) + self.assertIn("threepids", content) + self.assertIn("avatar_url", content) + self.assertIn("admin", content) + self.assertIn("deactivated", content) + self.assertIn("shadow_banned", content) + self.assertIn("password_hash", content) + self.assertIn("creation_ts", content) + self.assertIn("appservice_id", content) + self.assertIn("consent_server_notice_sent", content) + self.assertIn("consent_version", content) + self.assertIn("external_ids", content) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 8d609435c0053fc4decbc3f9c3603e728912749c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 1 Jul 2021 14:25:37 -0400 Subject: Move methods involving event authentication to EventAuthHandler. (#10268) Instead of mixing them with user authentication methods. --- changelog.d/10268.misc | 1 + synapse/api/auth.py | 75 +------------------------------- synapse/events/builder.py | 12 ++--- synapse/federation/federation_server.py | 6 +-- synapse/handlers/event_auth.py | 62 +++++++++++++++++++++++++- synapse/handlers/federation.py | 36 +++++++++------ synapse/handlers/message.py | 9 ++-- synapse/handlers/room.py | 3 +- synapse/handlers/space_summary.py | 6 ++- synapse/push/bulk_push_rule_evaluator.py | 4 +- tests/handlers/test_presence.py | 4 +- 11 files changed, 112 insertions(+), 106 deletions(-) create mode 100644 changelog.d/10268.misc (limited to 'tests') diff --git a/changelog.d/10268.misc b/changelog.d/10268.misc new file mode 100644 index 0000000000..9e3f60c72f --- /dev/null +++ b/changelog.d/10268.misc @@ -0,0 +1 @@ +Move event authentication methods from `Auth` to `EventAuthHandler`. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f8b068e563..307f5f9a94 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple import pymacaroons from netaddr import IPAddress @@ -28,10 +28,8 @@ from synapse.api.errors import ( InvalidClientTokenError, MissingClientTokenError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.appservice import ApplicationService from synapse.events import EventBase -from synapse.events.builder import EventBuilder from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing @@ -39,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester, StateMap, UserID, create_requester from synapse.util.caches.lrucache import LruCache from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry -from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,15 +44,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -AuthEventTypes = ( - EventTypes.Create, - EventTypes.Member, - EventTypes.PowerLevels, - EventTypes.JoinRules, - EventTypes.RoomHistoryVisibility, - EventTypes.ThirdPartyInvite, -) - # guests always get this device id. GUEST_DEVICE_ID = "guest_device" @@ -66,9 +54,7 @@ class _InvalidMacaroonException(Exception): class Auth: """ - FIXME: This class contains a mix of functions for authenticating users - of our client-server API and authenticating events added to room graphs. - The latter should be moved to synapse.handlers.event_auth.EventAuthHandler. + This class contains functions for authenticating users of our client-server API. """ def __init__(self, hs: "HomeServer"): @@ -90,18 +76,6 @@ class Auth: self._macaroon_secret_key = hs.config.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users - async def check_from_context( - self, room_version: str, event, context, do_sig_check=True - ) -> None: - auth_event_ids = event.auth_event_ids() - auth_events_by_id = await self.store.get_events(auth_event_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} - - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - event_auth.check( - room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check - ) - async def check_user_in_room( self, room_id: str, @@ -152,13 +126,6 @@ class Auth: raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - async def check_host_in_room(self, room_id: str, host: str) -> bool: - with Measure(self.clock, "check_host_in_room"): - return await self.store.is_host_joined(room_id, host) - - def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]: - return event_auth.get_public_keys(invite_event) - async def get_user_by_req( self, request: SynapseRequest, @@ -489,44 +456,6 @@ class Auth: """ return await self.store.is_server_admin(user) - def compute_auth_events( - self, - event: Union[EventBase, EventBuilder], - current_state_ids: StateMap[str], - for_verification: bool = False, - ) -> List[str]: - """Given an event and current state return the list of event IDs used - to auth an event. - - If `for_verification` is False then only return auth events that - should be added to the event's `auth_events`. - - Returns: - List of event IDs. - """ - - if event.type == EventTypes.Create: - return [] - - # Currently we ignore the `for_verification` flag even though there are - # some situations where we can drop particular auth events when adding - # to the event's `auth_events` (e.g. joins pointing to previous joins - # when room is publicly joinable). Dropping event IDs has the - # advantage that the auth chain for the room grows slower, but we use - # the auth chain in state resolution v2 to order events, which means - # care must be taken if dropping events to ensure that it doesn't - # introduce undesirable "state reset" behaviour. - # - # All of which sounds a bit tricky so we don't bother for now. - - auth_ids = [] - for etype, state_key in event_auth.auth_types_for_event(event): - auth_ev_id = current_state_ids.get((etype, state_key)) - if auth_ev_id: - auth_ids.append(auth_ev_id) - - return auth_ids - async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool: """Determine whether the user is allowed to edit the room's entry in the published room list. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index fb48ec8541..26e3950859 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -34,7 +34,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string if TYPE_CHECKING: - from synapse.api.auth import Auth + from synapse.handlers.event_auth import EventAuthHandler from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ class EventBuilder: """ _state: StateHandler - _auth: "Auth" + _event_auth_handler: "EventAuthHandler" _store: DataStore _clock: Clock _hostname: str @@ -125,7 +125,9 @@ class EventBuilder: state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_event_ids = self._auth.compute_auth_events(self, state_ids) + auth_event_ids = self._event_auth_handler.compute_auth_events( + self, state_ids + ) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: @@ -193,7 +195,7 @@ class EventBuilderFactory: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() def new(self, room_version: str, key_values: dict) -> EventBuilder: """Generate an event builder appropriate for the given room version @@ -229,7 +231,7 @@ class EventBuilderFactory: return EventBuilder( store=self.store, state=self.state, - auth=self.auth, + event_auth_handler=self._event_auth_handler, clock=self.clock, hostname=self.hostname, signing_key=self.signing_key, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e93b7577fe..b312d0b809 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -108,9 +108,9 @@ class FederationServer(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.auth = hs.get_auth() self.handler = hs.get_federation_handler() self.state = hs.get_state_handler() + self._event_auth_handler = hs.get_event_auth_handler() self.device_handler = hs.get_device_handler() @@ -420,7 +420,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -453,7 +453,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 989996b628..41dbdfd0a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Collection, Optional +from typing import TYPE_CHECKING, Collection, List, Optional, Union +from synapse import event_auth from synapse.api.constants import ( EventTypes, JoinRules, @@ -20,9 +21,11 @@ from synapse.api.constants import ( RestrictedJoinRuleTypes, ) from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersion +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase +from synapse.events.builder import EventBuilder from synapse.types import StateMap +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,8 +37,63 @@ class EventAuthHandler: """ def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() self._store = hs.get_datastore() + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ) -> None: + auth_event_ids = event.auth_event_ids() + auth_events_by_id = await self._store.get_events(auth_event_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} + + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + event_auth.check( + room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check + ) + + def compute_auth_events( + self, + event: Union[EventBase, EventBuilder], + current_state_ids: StateMap[str], + for_verification: bool = False, + ) -> List[str]: + """Given an event and current state return the list of event IDs used + to auth an event. + + If `for_verification` is False then only return auth events that + should be added to the event's `auth_events`. + + Returns: + List of event IDs. + """ + + if event.type == EventTypes.Create: + return [] + + # Currently we ignore the `for_verification` flag even though there are + # some situations where we can drop particular auth events when adding + # to the event's `auth_events` (e.g. joins pointing to previous joins + # when room is publicly joinable). Dropping event IDs has the + # advantage that the auth chain for the room grows slower, but we use + # the auth chain in state resolution v2 to order events, which means + # care must be taken if dropping events to ensure that it doesn't + # introduce undesirable "state reset" behaviour. + # + # All of which sounds a bit tricky so we don't bother for now. + + auth_ids = [] + for etype, state_key in event_auth.auth_types_for_event(event): + auth_ev_id = current_state_ids.get((etype, state_key)) + if auth_ev_id: + auth_ids.append(auth_ev_id) + + return auth_ids + + async def check_host_in_room(self, room_id: str, host: str) -> bool: + with Measure(self._clock, "check_host_in_room"): + return await self._store.is_host_joined(room_id, host) + async def check_restricted_join_rules( self, state_ids: StateMap[str], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d929c65131..991ec9919a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -250,7 +250,9 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) if not is_in_room: logger.info( "Ignoring PDU from %s as we're not in the room", @@ -1674,7 +1676,9 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) # now check that we are *still* in the room - is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) if not is_in_room: logger.info( "Got /make_join request for room %s we are no longer in", @@ -1705,7 +1709,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) @@ -1877,7 +1881,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: @@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: @@ -2111,7 +2115,7 @@ class FederationHandler(BaseHandler): async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2146,7 +2150,9 @@ class FederationHandler(BaseHandler): ) if event: - in_room = await self.auth.check_host_in_room(event.room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room( + event.room_id, origin + ) if not in_room: raise AuthError(403, "Host not in room.") @@ -2499,7 +2505,7 @@ class FederationHandler(BaseHandler): latest_events: List[str], limit: int, ) -> List[EventBase]: - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2562,7 +2568,7 @@ class FederationHandler(BaseHandler): if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) @@ -2991,7 +2997,7 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if await self.auth.check_host_in_room(room_id, self.hs.hostname): + if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname): room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) @@ -3011,7 +3017,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -3054,7 +3062,9 @@ class FederationHandler(BaseHandler): ) try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e @@ -3142,7 +3152,7 @@ class FederationHandler(BaseHandler): last_exception = None # type: Optional[Exception] # for each public key in the 3pid invite event - for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + for public_key_object in event_auth.get_public_keys(invite_event): try: # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 364c5cd2d3..66e40a915d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -385,6 +385,7 @@ class EventCreationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state = hs.get_state_handler() @@ -597,7 +598,7 @@ class EventCreationHandler: (e.type, e.state_key): e.event_id for e in auth_events } # Actually strip down and use the necessary auth events - auth_event_ids = self.auth.compute_auth_events( + auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, current_state_ids=auth_event_state_map, for_verification=False, @@ -1056,7 +1057,9 @@ class EventCreationHandler: assert event.content["membership"] == Membership.LEAVE else: try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -1381,7 +1384,7 @@ class EventCreationHandler: raise AuthError(403, "Redacting server ACL events is not permitted") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_map = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 835d874cee..579b1b93c5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() + self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config # Room state based off defined presets @@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler): }, ) old_room_version = await self.store.get_room_version_id(old_room_id) - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( old_room_version, tombstone_event, tombstone_context ) diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 266f369883..b585057ec3 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -472,7 +472,7 @@ class SpaceSummaryHandler: # If this is a request over federation, check if the host is in the room or # is in one of the spaces specified via the join rules. elif origin: - if await self._auth.check_host_in_room(room_id, origin): + if await self._event_auth_handler.check_host_in_room(room_id, origin): return True # Alternately, if the host has a user in any of the spaces specified @@ -485,7 +485,9 @@ class SpaceSummaryHandler: await self._event_auth_handler.get_rooms_that_allow_join(state_ids) ) for space_id in allowed_rooms: - if await self._auth.check_host_in_room(space_id, origin): + if await self._event_auth_handler.check_host_in_room( + space_id, origin + ): return True # otherwise, check if the room is peekable diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 350646f458..669ea462e2 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -104,7 +104,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a # time. Keyed off room_id. @@ -172,7 +172,7 @@ class BulkPushRuleEvaluator: # not having a power level event is an extreme edge case auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index dfb9b3a0fa..18e92e90d7 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -734,7 +734,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() # We don't actually check signatures in tests, so lets just create a # random key to use. @@ -846,7 +846,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): builder = EventBuilder( state=self.state, - auth=self.auth, + event_auth_handler=self._event_auth_handler, store=self.store, clock=self.clock, hostname=hostname, -- cgit 1.5.1 From 7a5873277ef456e8446a05468ccae2d81e363977 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Jul 2021 16:32:12 +0100 Subject: Add support for evicting cache entries based on last access time. (#10205) --- changelog.d/10205.feature | 1 + docs/sample_config.yaml | 62 ++++++----- mypy.ini | 1 + synapse/app/_base.py | 11 +- synapse/config/_base.pyi | 2 + synapse/config/cache.py | 70 +++++++----- synapse/util/caches/lrucache.py | 237 ++++++++++++++++++++++++++++++++++------ synapse/util/linked_list.py | 150 +++++++++++++++++++++++++ tests/util/test_lrucache.py | 46 +++++++- 9 files changed, 485 insertions(+), 95 deletions(-) create mode 100644 changelog.d/10205.feature create mode 100644 synapse/util/linked_list.py (limited to 'tests') diff --git a/changelog.d/10205.feature b/changelog.d/10205.feature new file mode 100644 index 0000000000..db3fd22587 --- /dev/null +++ b/changelog.d/10205.feature @@ -0,0 +1 @@ +Add support for evicting cache entries based on last access time. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6fcc022b47..c04aca1f42 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -673,35 +673,41 @@ retention: #event_cache_size: 10K caches: - # Controls the global cache factor, which is the default cache factor - # for all caches if a specific factor for that cache is not otherwise - # set. - # - # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment - # variable. Setting by environment variable takes priority over - # setting through the config file. - # - # Defaults to 0.5, which will half the size of all caches. - # - #global_factor: 1.0 + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 - # A dictionary of cache name to cache factor for that individual - # cache. Overrides the global cache factor for a given cache. - # - # These can also be set through environment variables comprised - # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital - # letters and underscores. Setting by environment variable - # takes priority over setting through the config file. - # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 - # - # Some caches have '*' and other characters that are not - # alphanumeric or underscores. These caches can be named with or - # without the special characters stripped. For example, to specify - # the cache factor for `*stateGroupCache*` via an environment - # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. - # - per_cache_factors: - #get_users_who_share_room_with_user: 2.0 + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + # Some caches have '*' and other characters that are not + # alphanumeric or underscores. These caches can be named with or + # without the special characters stripped. For example, to specify + # the cache factor for `*stateGroupCache*` via an environment + # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + # Controls how long an entry can be in a cache without having been + # accessed before being evicted. Defaults to None, which means + # entries are never evicted based on time. + # + #expiry_time: 30m ## Database ## diff --git a/mypy.ini b/mypy.ini index c4ff0e6618..72ce932d73 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,6 +75,7 @@ files = synapse/util/daemonize.py, synapse/util/hash.py, synapse/util/iterutils.py, + synapse/util/linked_list.py, synapse/util/metrics.py, synapse/util/macaroons.py, synapse/util/module_loader.py, diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8879136881..b30571fe49 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -21,7 +21,7 @@ import socket import sys import traceback import warnings -from typing import Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Awaitable, Callable, Iterable from cryptography.utils import CryptographyDeprecationWarning from typing_extensions import NoReturn @@ -41,10 +41,14 @@ from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats +from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # list of tuples of function, args list, kwargs dict @@ -312,7 +316,7 @@ def refresh_certificate(hs): logger.info("Context factories updated.") -async def start(hs: "synapse.server.HomeServer"): +async def start(hs: "HomeServer"): """ Start a Synapse server or worker. @@ -365,6 +369,9 @@ async def start(hs: "synapse.server.HomeServer"): load_legacy_spam_checkers(hs) + # If we've configured an expiry time for caches, start the background job now. + setup_expire_lru_cache_entries(hs) + # It is now safe to start your Synapse. hs.start_listening() hs.get_datastore().db_pool.start_profiling() diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 23ca0c83c1..06fbd1166b 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -5,6 +5,7 @@ from synapse.config import ( api, appservice, auth, + cache, captcha, cas, consent, @@ -88,6 +89,7 @@ class RootConfig: tracer: tracer.TracerConfig redis: redis.RedisConfig modules: modules.ModulesConfig + caches: cache.CacheConfig federation: federation.FederationConfig config_classes: List = ... diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 91165ee1ce..7789b40323 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -116,35 +116,41 @@ class CacheConfig(Config): #event_cache_size: 10K caches: - # Controls the global cache factor, which is the default cache factor - # for all caches if a specific factor for that cache is not otherwise - # set. - # - # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment - # variable. Setting by environment variable takes priority over - # setting through the config file. - # - # Defaults to 0.5, which will half the size of all caches. - # - #global_factor: 1.0 - - # A dictionary of cache name to cache factor for that individual - # cache. Overrides the global cache factor for a given cache. - # - # These can also be set through environment variables comprised - # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital - # letters and underscores. Setting by environment variable - # takes priority over setting through the config file. - # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 - # - # Some caches have '*' and other characters that are not - # alphanumeric or underscores. These caches can be named with or - # without the special characters stripped. For example, to specify - # the cache factor for `*stateGroupCache*` via an environment - # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. - # - per_cache_factors: - #get_users_who_share_room_with_user: 2.0 + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + # Some caches have '*' and other characters that are not + # alphanumeric or underscores. These caches can be named with or + # without the special characters stripped. For example, to specify + # the cache factor for `*stateGroupCache*` via an environment + # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + # Controls how long an entry can be in a cache without having been + # accessed before being evicted. Defaults to None, which means + # entries are never evicted based on time. + # + #expiry_time: 30m """ def read_config(self, config, **kwargs): @@ -200,6 +206,12 @@ class CacheConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) + expiry_time = cache_config.get("expiry_time") + if expiry_time: + self.expiry_time_msec = self.parse_duration(expiry_time) + else: + self.expiry_time_msec = None + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index d89e9d9b1d..4b9d0433ff 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import threading +import weakref from functools import wraps from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -31,10 +34,19 @@ from typing import ( from typing_extensions import Literal +from twisted.internet import reactor + from synapse.config import cache as cache_config -from synapse.util import caches +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.util import Clock, caches from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.linked_list import ListNode + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) try: from pympler.asizeof import Asizer @@ -82,19 +94,126 @@ def enumerate_leaves(node, depth): yield m +P = TypeVar("P") + + +class _TimedListNode(ListNode[P]): + """A `ListNode` that tracks last access time.""" + + __slots__ = ["last_access_ts_secs"] + + def update_last_access(self, clock: Clock): + self.last_access_ts_secs = int(clock.time()) + + +# Whether to insert new cache entries to the global list. We only add to it if +# time based eviction is enabled. +USE_GLOBAL_LIST = False + +# A linked list of all cache entries, allowing efficient time based eviction. +GLOBAL_ROOT = ListNode["_Node"].create_root_node() + + +@wrap_as_background_process("LruCache._expire_old_entries") +async def _expire_old_entries(clock: Clock, expiry_seconds: int): + """Walks the global cache list to find cache entries that haven't been + accessed in the given number of seconds. + """ + + now = int(clock.time()) + node = GLOBAL_ROOT.prev_node + assert node is not None + + i = 0 + + logger.debug("Searching for stale caches") + + while node is not GLOBAL_ROOT: + # Only the root node isn't a `_TimedListNode`. + assert isinstance(node, _TimedListNode) + + if node.last_access_ts_secs > now - expiry_seconds: + break + + cache_entry = node.get_cache_entry() + next_node = node.prev_node + + # The node should always have a reference to a cache entry and a valid + # `prev_node`, as we only drop them when we remove the node from the + # list. + assert next_node is not None + assert cache_entry is not None + cache_entry.drop_from_cache() + + # If we do lots of work at once we yield to allow other stuff to happen. + if (i + 1) % 10000 == 0: + logger.debug("Waiting during drop") + await clock.sleep(0) + logger.debug("Waking during drop") + + node = next_node + + # If we've yielded then our current node may have been evicted, so we + # need to check that its still valid. + if node.prev_node is None: + break + + i += 1 + + logger.info("Dropped %d items from caches", i) + + +def setup_expire_lru_cache_entries(hs: "HomeServer"): + """Start a background job that expires all cache entries if they have not + been accessed for the given number of seconds. + """ + if not hs.config.caches.expiry_time_msec: + return + + logger.info( + "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000 + ) + + global USE_GLOBAL_LIST + USE_GLOBAL_LIST = True + + clock = hs.get_clock() + clock.looping_call( + _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000 + ) + + class _Node: - __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"] + __slots__ = [ + "_list_node", + "_global_list_node", + "_cache", + "key", + "value", + "callbacks", + "memory", + ] def __init__( self, - prev_node, - next_node, + root: "ListNode[_Node]", key, value, + cache: "weakref.ReferenceType[LruCache]", + clock: Clock, callbacks: Collection[Callable[[], None]] = (), ): - self.prev_node = prev_node - self.next_node = next_node + self._list_node = ListNode.insert_after(self, root) + self._global_list_node = None + if USE_GLOBAL_LIST: + self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT) + self._global_list_node.update_last_access(clock) + + # We store a weak reference to the cache object so that this _Node can + # remove itself from the cache. If the cache is dropped we ensure we + # remove our entries in the lists. + self._cache = cache + self.key = key self.value = value @@ -116,11 +235,16 @@ class _Node: self.memory = ( _get_size_of(key) + _get_size_of(value) + + _get_size_of(self._list_node, recurse=False) + _get_size_of(self.callbacks, recurse=False) + _get_size_of(self, recurse=False) ) self.memory += _get_size_of(self.memory, recurse=False) + if self._global_list_node: + self.memory += _get_size_of(self._global_list_node, recurse=False) + self.memory += _get_size_of(self._global_list_node.last_access_ts_secs) + def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None: """Add to stored list of callbacks, removing duplicates.""" @@ -147,6 +271,32 @@ class _Node: self.callbacks = None + def drop_from_cache(self) -> None: + """Drop this node from the cache. + + Ensures that the entry gets removed from the cache and that we get + removed from all lists. + """ + cache = self._cache() + if not cache or not cache.pop(self.key, None): + # `cache.pop` should call `drop_from_lists()`, unless this Node had + # already been removed from the cache. + self.drop_from_lists() + + def drop_from_lists(self) -> None: + """Remove this node from the cache lists.""" + self._list_node.remove_from_list() + + if self._global_list_node: + self._global_list_node.remove_from_list() + + def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None: + """Moves this node to the front of all the lists its in.""" + self._list_node.move_after(cache_list_root) + if self._global_list_node: + self._global_list_node.move_after(GLOBAL_ROOT) + self._global_list_node.update_last_access(clock) + class LruCache(Generic[KT, VT]): """ @@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]): size_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, + clock: Optional[Clock] = None, ): """ Args: @@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]): apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config """ + # Default `clock` to something sensible. Note that we rename it to + # `real_clock` so that mypy doesn't think its still `Optional`. + if clock is None: + real_clock = Clock(reactor) + else: + real_clock = clock + cache = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]): # this is exposed for access from outside this class self.metrics = metrics - list_root = _Node(None, None, None, None) - list_root.next_node = list_root - list_root.prev_node = list_root + # We create a single weakref to self here so that we don't need to keep + # creating more each time we create a `_Node`. + weak_ref_to_self = weakref.ref(self) + + list_root = ListNode[_Node].create_root_node() lock = threading.Lock() def evict(): while cache_len() > self.max_size: + # Get the last node in the list (i.e. the oldest node). todelete = list_root.prev_node - evicted_len = delete_node(todelete) - cache.pop(todelete.key, None) + + # The list root should always have a valid `prev_node` if the + # cache is not empty. + assert todelete is not None + + # The node should always have a reference to a cache entry, as + # we only drop the cache entry when we remove the node from the + # list. + node = todelete.get_cache_entry() + assert node is not None + + evicted_len = delete_node(node) + cache.pop(node.key, None) if metrics: metrics.inc_evictions(evicted_len) @@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]): self.len = synchronized(cache_len) def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): - prev_node = list_root - next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value, callbacks) - prev_node.next_node = node - next_node.prev_node = node + node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks) cache[key] = node if size_callback: @@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]): if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node): - prev_node = node.prev_node - next_node = node.next_node - prev_node.next_node = next_node - next_node.prev_node = prev_node - prev_node = list_root - next_node = prev_node.next_node - node.prev_node = prev_node - node.next_node = next_node - prev_node.next_node = node - next_node.prev_node = node - - def delete_node(node): - prev_node = node.prev_node - next_node = node.next_node - prev_node.next_node = next_node - next_node.prev_node = prev_node + def move_node_to_front(node: _Node): + node.move_to_front(real_clock, list_root) + + def delete_node(node: _Node) -> int: + node.drop_from_lists() deleted_len = 1 if size_callback: @@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_clear() -> None: - list_root.next_node = list_root - list_root.prev_node = list_root for node in cache.values(): node.run_and_clear_callbacks() + node.drop_from_lists() + + assert list_root.next_node == list_root + assert list_root.prev_node == list_root + cache.clear() if size_callback: cached_cache_len[0] = 0 @@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]): self._on_resize() return True return False + + def __del__(self) -> None: + # We're about to be deleted, so we make sure to clear up all the nodes + # and run callbacks, etc. + # + # This happens e.g. in the sync code where we have an expiring cache of + # lru caches. + self.clear() diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py new file mode 100644 index 0000000000..a456b136f0 --- /dev/null +++ b/synapse/util/linked_list.py @@ -0,0 +1,150 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A circular doubly linked list implementation. +""" + +import threading +from typing import Generic, Optional, Type, TypeVar + +P = TypeVar("P") +LN = TypeVar("LN", bound="ListNode") + + +class ListNode(Generic[P]): + """A node in a circular doubly linked list, with an (optional) reference to + a cache entry. + + The reference should only be `None` for the root node or if the node has + been removed from the list. + """ + + # A lock to protect mutating the list prev/next pointers. + _LOCK = threading.Lock() + + # We don't use attrs here as in py3.6 you can't have `attr.s(slots=True)` + # and inherit from `Generic` for some reason + __slots__ = [ + "cache_entry", + "prev_node", + "next_node", + ] + + def __init__(self, cache_entry: Optional[P] = None) -> None: + self.cache_entry = cache_entry + self.prev_node: Optional[ListNode[P]] = None + self.next_node: Optional[ListNode[P]] = None + + @classmethod + def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]": + """Create a new linked list by creating a "root" node, which is a node + that has prev_node/next_node pointing to itself and no associated cache + entry. + """ + root = cls() + root.prev_node = root + root.next_node = root + return root + + @classmethod + def insert_after( + cls: Type[LN], + cache_entry: P, + node: "ListNode[P]", + ) -> LN: + """Create a new list node that is placed after the given node. + + Args: + cache_entry: The associated cache entry. + node: The existing node in the list to insert the new entry after. + """ + new_node = cls(cache_entry) + with cls._LOCK: + new_node._refs_insert_after(node) + return new_node + + def remove_from_list(self): + """Remove this node from the list.""" + with self._LOCK: + self._refs_remove_node_from_list() + + # We drop the reference to the cache entry to break the reference cycle + # between the list node and cache entry, allowing the two to be dropped + # immediately rather than at the next GC. + self.cache_entry = None + + def move_after(self, node: "ListNode"): + """Move this node from its current location in the list to after the + given node. + """ + with self._LOCK: + # We assert that both this node and the target node is still "alive". + assert self.prev_node + assert self.next_node + assert node.prev_node + assert node.next_node + + assert self is not node + + # Remove self from the list + self._refs_remove_node_from_list() + + # Insert self back into the list, after target node + self._refs_insert_after(node) + + def _refs_remove_node_from_list(self): + """Internal method to *just* remove the node from the list, without + e.g. clearing out the cache entry. + """ + if self.prev_node is None or self.next_node is None: + # We've already been removed from the list. + return + + prev_node = self.prev_node + next_node = self.next_node + + prev_node.next_node = next_node + next_node.prev_node = prev_node + + # We set these to None so that we don't get circular references, + # allowing us to be dropped without having to go via the GC. + self.prev_node = None + self.next_node = None + + def _refs_insert_after(self, node: "ListNode"): + """Internal method to insert the node after the given node.""" + + # This method should only be called when we're not already in the list. + assert self.prev_node is None + assert self.next_node is None + + # We expect the given node to be in the list and thus have valid + # prev/next refs. + assert node.next_node + assert node.prev_node + + prev_node = node + next_node = node.next_node + + self.prev_node = prev_node + self.next_node = next_node + + prev_node.next_node = self + next_node.prev_node = self + + def get_cache_entry(self) -> Optional[P]: + """Get the cache entry, returns None if this is the root node (i.e. + cache_entry is None) or if the entry has been dropped. + """ + return self.cache_entry diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 377904e72e..6578f3411e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -15,7 +15,7 @@ from unittest.mock import Mock -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache from tests import unittest @@ -260,3 +260,47 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): self.assertEquals(cache["key3"], [3]) self.assertEquals(cache["key4"], [4]) self.assertEquals(cache["key5"], [5, 6]) + + +class TimeEvictionTestCase(unittest.HomeserverTestCase): + """Test that time based eviction works correctly.""" + + def default_config(self): + config = super().default_config() + + config.setdefault("caches", {})["expiry_time"] = "30m" + + return config + + def test_evict(self): + setup_expire_lru_cache_entries(self.hs) + + cache = LruCache(5, clock=self.hs.get_clock()) + + # Check that we evict entries we haven't accessed for 30 minutes. + cache["key1"] = 1 + cache["key2"] = 2 + + self.reactor.advance(20 * 60) + + self.assertEqual(cache.get("key1"), 1) + + self.reactor.advance(20 * 60) + + # We have only touched `key1` in the last 30m, so we expect that to + # still be in the cache while `key2` should have been evicted. + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), None) + + # Check that re-adding an expired key works correctly. + cache["key2"] = 3 + self.assertEqual(cache.get("key2"), 3) + + self.reactor.advance(20 * 60) + + self.assertEqual(cache.get("key2"), 3) + + self.reactor.advance(20 * 60) + + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), 3) -- cgit 1.5.1 From bcb0962a7250d6c1430ad42f5ed234ffea8f2468 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:08:53 +0200 Subject: Fix deactivate a user if he does not have a profile (#10252) --- changelog.d/10252.bugfix | 1 + synapse/storage/databases/main/profile.py | 8 +-- tests/rest/admin/test_user.py | 86 ++++++++++++++++++++++++------- 3 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 changelog.d/10252.bugfix (limited to 'tests') diff --git a/changelog.d/10252.bugfix b/changelog.d/10252.bugfix new file mode 100644 index 0000000000..c8ddd14528 --- /dev/null +++ b/changelog.d/10252.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.26.0 where only users who have set profile information could be deactivated with erasure enabled. diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 9b4e95e134..ba7075caa5 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -73,20 +73,20 @@ class ProfileWorkerStore(SQLBaseStore): async def set_profile_displayname( self, user_localpart: str, new_displayname: Optional[str] ) -> None: - await self.db_pool.simple_update_one( + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, + values={"displayname": new_displayname}, desc="set_profile_displayname", ) async def set_profile_avatar_url( self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: - await self.db_pool.simple_update_one( + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a34d051734..4fccce34fd 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -939,7 +939,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -950,7 +950,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -960,7 +960,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): @@ -990,7 +990,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): @@ -1006,7 +1006,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): def test_deactivate_user_erase_true(self): """ - Test deactivating an user and set `erase` to `true` + Test deactivating a user and set `erase` to `true` """ # Get user @@ -1016,24 +1016,22 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) - # Deactivate user - body = json.dumps({"erase": True}) - + # Deactivate and erase user channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"erase": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1042,7 +1040,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1053,7 +1051,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): def test_deactivate_user_erase_false(self): """ - Test deactivating an user and set `erase` to `false` + Test deactivating a user and set `erase` to `false` """ # Get user @@ -1063,7 +1061,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1071,13 +1069,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("User1", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"erase": False}) - channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"erase": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1089,7 +1085,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1098,6 +1094,60 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", False) + def test_deactivate_user_erase_true_no_profile(self): + """ + Test deactivating a user and set `erase` to `true` + if user has no profile information (stored in the database table `profiles`). + """ + + # Users normally have an entry in `profiles`, but occasionally they are created without one. + # To test deactivation for users without a profile, we delete the profile information for our user. + self.get_success( + self.store.db_pool.simple_delete_one( + table="profiles", keyvalues={"user_id": "user"} + ) + ) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + # Deactivate and erase user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"erase": True}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) -- cgit 1.5.1