From cab4a52535097c5836fe67c5e09e8350d7ccf03c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 Feb 2020 13:08:43 +0000 Subject: set worker_app for frontend proxy test (#7003) to stop the federationhandler trying to do master stuff --- tests/app/test_frontend_proxy.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'tests') diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 160e55aca9..d3feafa1b7 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -27,6 +27,11 @@ class FrontendProxyTests(HomeserverTestCase): return hs + def default_config(self, name="test"): + c = super().default_config(name) + c["worker_app"] = "synapse.app.frontend_proxy" + return c + def test_listen_http_with_presence_enabled(self): """ When presence is on, the stub servlet will not register. -- cgit 1.5.1 From 9b06d8f8a62dc5c423aa9a694e0759eaf1c3c77e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 28 Feb 2020 10:58:05 +0100 Subject: Fixed set a user as an admin with the new API (#6928) Fix #6910 --- changelog.d/6910.bugfix | 1 + synapse/rest/admin/users.py | 6 +- synapse/storage/data_stores/main/registration.py | 16 +- tests/rest/admin/test_user.py | 218 +++++++++++++++++++---- 4 files changed, 199 insertions(+), 42 deletions(-) create mode 100644 changelog.d/6910.bugfix (limited to 'tests') diff --git a/changelog.d/6910.bugfix b/changelog.d/6910.bugfix new file mode 100644 index 0000000000..707f1ff7b5 --- /dev/null +++ b/changelog.d/6910.bugfix @@ -0,0 +1 @@ +Fixed set a user as an admin with the admin API `PUT /_synapse/admin/v2/users/`. Contributed by @dklimpel. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c5b461a236..80f959248d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.admin_handler.set_user_server_admin( - target_user, set_admin_to - ) + await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: if ( @@ -651,6 +649,6 @@ class UserAdminServlet(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.store.set_user_server_admin(target_user, set_admin_to) + await self.store.set_server_admin(target_user, set_admin_to) return 200, {} diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 49306642ed..3e53c8568a 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.db.simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) + + def set_server_admin_txn(txn): + self.db.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cbe4a6a51f..6416fb5d2a 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -16,6 +16,7 @@ import hashlib import hmac import json +import urllib.parse from mock import Mock @@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.url = "/_synapse/admin/v2/users/@bob:test" - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_token, + "GET", url, access_token=self.other_user_token, ) self.render(request) @@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("You are not a server admin", channel.json_body["error"]) request, channel = self.make_request( - "PUT", self.url, access_token=self.other_user_token, content=b"{}", + "PUT", url, access_token=self.other_user_token, content=b"{}", ) self.render(request) @@ -417,24 +420,73 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_requester_is_admin(self): + def test_create_server_admin(self): """ - If the user is a server admin, a new user is created. + Check that a new admin user is created successfully. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user (server admin) body = json.dumps( { "password": "abc123", "admin": True, + "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], } ) + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_create_user(self): + """ + Check that a new regular user is created successfully. + """ + self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user + body = json.dumps( + { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -442,29 +494,38 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(0, channel.json_body["deactivated"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_set_password(self): + """ + Test setting a new password for another user. + """ + self.hs.config.registration_shared_secret = None # Change password body = json.dumps({"password": "hahaha"}) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -472,41 +533,133 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_set_displayname(self): + """ + Test setting the displayname of another user. + """ + self.hs.config.registration_shared_secret = None + # Modify user + body = json.dumps({"displayname": "foobar"}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + def test_set_threepid(self): + """ + Test setting threepid for an other user. + """ + self.hs.config.registration_shared_secret = None + + # Delete old and add new threepid to user body = json.dumps( - { - "displayname": "foobar", - "deactivated": True, - "threepids": [{"medium": "email", "address": "bob2@bob.bob"}], - } + {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} ) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + def test_deactivate_user(self): + """ + Test deactivating another user. + """ + + # Deactivate user + body = json.dumps({"deactivated": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) # the user is deactivated, the threepid will be deleted # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", self.url_other_user, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(1, channel.json_body["deactivated"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + def test_set_user_as_admin(self): + """ + Test setting the admin flag on a user. + """ + self.hs.config.registration_shared_secret = None + + # Set a user as an admin + body = json.dumps({"admin": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) def test_accidental_deactivation_prevention(self): """ @@ -514,13 +667,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): for the deactivated body parameter """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" # Create user body = json.dumps({"password": "abc123"}) request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -532,7 +686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) @@ -546,7 +700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -556,7 +710,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Check user is not deactivated request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) -- cgit 1.5.1 From bbeee33d63c43cb80118c0dccf8abd9d4ac1b8f3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 28 Feb 2020 10:58:05 +0100 Subject: Fixed set a user as an admin with the new API (#6928) Fix #6910 --- changelog.d/6910.bugfix | 1 + synapse/rest/admin/users.py | 6 +- synapse/storage/data_stores/main/registration.py | 16 +- tests/rest/admin/test_user.py | 209 ++++++++++++++++++++--- 4 files changed, 194 insertions(+), 38 deletions(-) create mode 100644 changelog.d/6910.bugfix (limited to 'tests') diff --git a/changelog.d/6910.bugfix b/changelog.d/6910.bugfix new file mode 100644 index 0000000000..707f1ff7b5 --- /dev/null +++ b/changelog.d/6910.bugfix @@ -0,0 +1 @@ +Fixed set a user as an admin with the admin API `PUT /_synapse/admin/v2/users/`. Contributed by @dklimpel. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2107b5dc56..064908fbb0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.admin_handler.set_user_server_admin( - target_user, set_admin_to - ) + await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: if ( @@ -648,6 +646,6 @@ class UserAdminServlet(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.store.set_user_server_admin(target_user, set_admin_to) + await self.store.set_server_admin(target_user, set_admin_to) return 200, {} diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 49306642ed..3e53c8568a 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.db.simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) + + def set_server_admin_txn(txn): + self.db.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 490ce8f55d..70688c2494 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -16,6 +16,7 @@ import hashlib import hmac import json +import urllib.parse from mock import Mock @@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.url = "/_synapse/admin/v2/users/@bob:test" - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_token, + "GET", url, access_token=self.other_user_token, ) self.render(request) @@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("You are not a server admin", channel.json_body["error"]) request, channel = self.make_request( - "PUT", self.url, access_token=self.other_user_token, content=b"{}", + "PUT", url, access_token=self.other_user_token, content=b"{}", ) self.render(request) @@ -417,24 +420,73 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_requester_is_admin(self): + def test_create_server_admin(self): """ - If the user is a server admin, a new user is created. + Check that a new admin user is created successfully. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user (server admin) body = json.dumps( { "password": "abc123", "admin": True, + "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], } ) + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_create_user(self): + """ + Check that a new regular user is created successfully. + """ + self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user + body = json.dumps( + { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -442,29 +494,38 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(0, channel.json_body["deactivated"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_set_password(self): + """ + Test setting a new password for another user. + """ + self.hs.config.registration_shared_secret = None # Change password body = json.dumps({"password": "hahaha"}) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -472,38 +533,130 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_set_displayname(self): + """ + Test setting the displayname of another user. + """ + self.hs.config.registration_shared_secret = None + # Modify user + body = json.dumps({"displayname": "foobar"}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + def test_set_threepid(self): + """ + Test setting threepid for an other user. + """ + self.hs.config.registration_shared_secret = None + + # Delete old and add new threepid to user body = json.dumps( - { - "displayname": "foobar", - "deactivated": True, - "threepids": [{"medium": "email", "address": "bob2@bob.bob"}], - } + {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} ) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + def test_deactivate_user(self): + """ + Test deactivating another user. + """ + + # Deactivate user + body = json.dumps({"deactivated": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) # the user is deactivated, the threepid will be deleted # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", self.url_other_user, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(1, channel.json_body["deactivated"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + def test_set_user_as_admin(self): + """ + Test setting the admin flag on a user. + """ + self.hs.config.registration_shared_secret = None + + # Set a user as an admin + body = json.dumps({"admin": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) -- cgit 1.5.1 From b2bd54a2e31d9a248f73fadb184ae9b4cbdb49f9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 2 Mar 2020 16:36:32 +0000 Subject: Add a confirmation step to the SSO login flow --- docs/sample_config.yaml | 34 ++++++++++ synapse/config/_base.pyi | 2 + synapse/config/homeserver.py | 2 + synapse/config/sso.py | 74 +++++++++++++++++++++ synapse/res/templates/sso_redirect_confirm.html | 14 ++++ synapse/rest/client/v1/login.py | 40 ++++++++++-- tests/rest/client/v1/test_login.py | 85 +++++++++++++++++++++++++ 7 files changed, 245 insertions(+), 6 deletions(-) create mode 100644 synapse/config/sso.py create mode 100644 synapse/res/templates/sso_redirect_confirm.html (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 8a036071e1..bbb8a4d934 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1360,6 +1360,40 @@ saml2_config: # # name: value +# Additional settings to use with single-sign on systems such as SAML2 and CAS. +# +sso: + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page for confirmation of redirect during authentication: + # 'sso_redirect_confirm.html'. + # + # When rendering, this template is given three variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * display_url: the same as `redirect_url`, but with the query + # parameters stripped. The intention is to have a + # human-readable URL to show to users, not to use it as + # the final address to redirect to. Needs manual escaping + # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + + # The JWT needs to contain a globally unique "sub" (subject) claim. # #jwt_config: diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 86bc965ee4..3053fc9d27 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -24,6 +24,7 @@ from synapse.config import ( server, server_notices_config, spam_checker, + sso, stats, third_party_event_rules, tls, @@ -57,6 +58,7 @@ class RootConfig: key: key.KeyConfig saml2: saml2_config.SAML2Config cas: cas.CasConfig + sso: sso.SSOConfig jwt: jwt_config.JWTConfig password: password.PasswordConfig email: emailconfig.EmailConfig diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 6e348671c7..b4bca08b20 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -38,6 +38,7 @@ from .saml2_config import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig +from .sso import SSOConfig from .stats import StatsConfig from .third_party_event_rules import ThirdPartyRulesConfig from .tls import TlsConfig @@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig): KeyConfig, SAML2Config, CasConfig, + SSOConfig, JWTConfig, PasswordConfig, EmailConfig, diff --git a/synapse/config/sso.py b/synapse/config/sso.py new file mode 100644 index 0000000000..f426b65b4f --- /dev/null +++ b/synapse/config/sso.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict + +import pkg_resources + +from ._base import Config, ConfigError + + +class SSOConfig(Config): + """SSO Configuration + """ + + section = "sso" + + def read_config(self, config, **kwargs): + sso_config = config.get("sso") or {} # type: Dict[str, Any] + + # Pick a template directory in order of: + # * The sso-specific template_dir + # * /path/to/synapse/install/res/templates + template_dir = sso_config.get("template_dir") + if not template_dir: + template_dir = pkg_resources.resource_filename("synapse", "res/templates",) + + self.sso_redirect_confirm_template_dir = template_dir + + def generate_config_section(self, **kwargs): + return """\ + # Additional settings to use with single-sign on systems such as SAML2 and CAS. + # + sso: + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page for a confirmation step before redirecting back to the client + # with the login token: 'sso_redirect_confirm.html'. + # + # When rendering, this template is given three variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * display_url: the same as `redirect_url`, but with the query + # parameters stripped. The intention is to have a + # human-readable URL to show to users, not to use it as + # the final address to redirect to. Needs manual escaping + # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + """ diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html new file mode 100644 index 0000000000..20a15e1e74 --- /dev/null +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -0,0 +1,14 @@ + + + + + SSO redirect confirmation + + +

The application at {{ display_url | e }} is requesting full access to your {{ server_name }} Matrix account.

+

If you don't recognise this address, you should ignore this and close this tab.

+

+ I trust this address +

+ + \ No newline at end of file diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1294e080dc..1acfd01d8e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -29,6 +29,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest +from synapse.push.mailer import load_jinja2_templates from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart @@ -548,6 +549,13 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() + # Load the redirect page HTML template + self._template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], + )[0] + + self._server_name = hs.config.server_name + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): @@ -592,21 +600,41 @@ class SSOAuthHandler(object): request: client_redirect_url: """ - + # Create a login token login_token = self._macaroon_gen.generate_short_term_login_token( registered_user_id ) - redirect_url = self._add_login_token_to_redirect_url( - client_redirect_url, login_token + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. + redirect_url_no_params = client_redirect_url.split("?")[0] + + # Append the login token to the original redirect URL (i.e. with its query + # parameters kept intact) to build the URL to which the template needs to + # redirect the users once they have clicked on the confirmation link. + redirect_url = self._add_query_param_to_url( + client_redirect_url, "loginToken", login_token + ) + + # Serve the redirect confirmation page + html = self._template.render( + display_url=redirect_url_no_params, + redirect_url=redirect_url, + server_name=self._server_name, ) - request.redirect(redirect_url) + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html),)) + request.write(html.encode("utf8")) finish_request(request) @staticmethod - def _add_login_token_to_redirect_url(url, token): + def _add_query_param_to_url(url, param_name, param): url_parts = list(urllib.parse.urlparse(url)) query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"loginToken": token}) + query.update({param_name: param}) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eae5411325..2b8ad5c753 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,4 +1,7 @@ import json +import urllib.parse + +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import login @@ -252,3 +255,85 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(channel.code, 200, channel.result) + + +class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["enable_registration"] = True + config["cas_config"] = { + "enabled": True, + "server_url": "https://fake.test", + "service_url": "https://matrix.goodserver.com:8448", + } + config["public_baseurl"] = self.base_url + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return """ + + + username + PGTIOU-84678-8a9d... + + https://proxy2/pgtUrl + https://proxy1/pgtUrl + + + + """ + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, proxied_http_client=mocked_http_client, + ) + + return self.hs + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Test that the response is HTML. + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) -- cgit 1.5.1 From b68041df3dcbcf3ca04c500d1712aa22a3c2580c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 2 Mar 2020 17:05:09 +0000 Subject: Add a whitelist for the SSO confirmation step. --- docs/sample_config.yaml | 22 +++++++++++++++++++--- synapse/config/sso.py | 18 ++++++++++++++++++ synapse/rest/client/v1/login.py | 26 ++++++++++++++++++-------- tests/rest/client/v1/test_login.py | 32 +++++++++++++++++++++++++++++--- 4 files changed, 84 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index bbb8a4d934..f719ec696f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1363,6 +1363,22 @@ saml2_config: # Additional settings to use with single-sign on systems such as SAML2 and CAS. # sso: + # A list of client URLs which are whitelisted so that the user does not + # have to confirm giving access to their account to the URL. Any client + # whose URL starts with an entry in the following list will not be subject + # to an additional confirmation step after the SSO login is completed. + # + # WARNING: An entry such as "https://my.client" is insecure, because it + # will also match "https://my.client.evil.site", exposing your users to + # phishing attacks from evil.site. To avoid this, include a slash after the + # hostname: "https://my.client/". + # + # By default, this list is empty. + # + #client_whitelist: + # - https://riot.im/develop + # - https://my.custom.client/ + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # @@ -1372,8 +1388,8 @@ sso: # # Synapse will look for the following templates in this directory: # - # * HTML page for confirmation of redirect during authentication: - # 'sso_redirect_confirm.html'. + # * HTML page for a confirmation step before redirecting back to the client + # with the login token: 'sso_redirect_confirm.html'. # # When rendering, this template is given three variables: # * redirect_url: the URL the user is about to be redirected to. Needs @@ -1381,7 +1397,7 @@ sso: # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). # # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a + # parameters stripped. The intention is to have a # human-readable URL to show to users, not to use it as # the final address to redirect to. Needs manual escaping # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). diff --git a/synapse/config/sso.py b/synapse/config/sso.py index f426b65b4f..56299bd4e4 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -37,11 +37,29 @@ class SSOConfig(Config): self.sso_redirect_confirm_template_dir = template_dir + self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. # sso: + # A list of client URLs which are whitelisted so that the user does not + # have to confirm giving access to their account to the URL. Any client + # whose URL starts with an entry in the following list will not be subject + # to an additional confirmation step after the SSO login is completed. + # + # WARNING: An entry such as "https://my.client" is insecure, because it + # will also match "https://my.client.evil.site", exposing your users to + # phishing attacks from evil.site. To avoid this, include a slash after the + # hostname: "https://my.client/". + # + # By default, this list is empty. + # + #client_whitelist: + # - https://riot.im/develop + # - https://my.custom.client/ + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1acfd01d8e..b2bc7537db 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -556,6 +556,9 @@ class SSOAuthHandler(object): self._server_name = hs.config.server_name + # cast to tuple for use with str.startswith + self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): @@ -605,11 +608,6 @@ class SSOAuthHandler(object): registered_user_id ) - # Remove the query parameters from the redirect URL to get a shorter version of - # it. This is only to display a human-readable URL in the template, but not the - # URL we redirect users to. - redirect_url_no_params = client_redirect_url.split("?")[0] - # Append the login token to the original redirect URL (i.e. with its query # parameters kept intact) to build the URL to which the template needs to # redirect the users once they have clicked on the confirmation link. @@ -617,17 +615,29 @@ class SSOAuthHandler(object): client_redirect_url, "loginToken", login_token ) - # Serve the redirect confirmation page + # if the client is whitelisted, we can redirect straight to it + if client_redirect_url.startswith(self._whitelisted_sso_clients): + request.redirect(redirect_url) + finish_request(request) + return + + # Otherwise, serve the redirect confirmation page. + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. + redirect_url_no_params = client_redirect_url.split("?")[0] + html = self._template.render( display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, - ) + ).encode("utf-8") request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%d" % (len(html),)) - request.write(html.encode("utf8")) + request.write(html) finish_request(request) @staticmethod diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2b8ad5c753..da2c9bfa1e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -268,13 +268,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.redirect_path = "_synapse/client/login/sso/redirect/confirm" config = self.default_config() - config["enable_registration"] = True config["cas_config"] = { "enabled": True, "server_url": "https://fake.test", "service_url": "https://matrix.goodserver.com:8448", } - config["public_baseurl"] = self.base_url async def get_raw(uri, args): """Return an example response payload from a call to the `/proxyValidate` @@ -310,7 +308,7 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ - base_url = "/login/cas/ticket?redirectUrl" + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" redirect_url = "https://dodgy-site.com/" url_parts = list(urllib.parse.urlparse(base_url)) @@ -325,6 +323,7 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.render(request) # Test that the response is HTML. + self.assertEqual(channel.code, 200) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": @@ -337,3 +336,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): # And that it contains our redirect link self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url + """ + redirect_url = "https://legit-site.com/" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) -- cgit 1.5.1 From 7dcbc33a1be04c46b930699c03c15bc759f4b22c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 3 Mar 2020 07:12:45 -0500 Subject: Validate the alt_aliases property of canonical alias events (#6971) --- changelog.d/6971.feature | 1 + synapse/api/errors.py | 1 + synapse/handlers/directory.py | 14 ++-- synapse/handlers/message.py | 47 ++++++++++- synapse/types.py | 15 ++-- tests/handlers/test_directory.py | 66 +++++++-------- tests/rest/client/v1/test_rooms.py | 160 +++++++++++++++++++++++++++++++++++++ tests/test_types.py | 2 +- 8 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 changelog.d/6971.feature (limited to 'tests') diff --git a/changelog.d/6971.feature b/changelog.d/6971.feature new file mode 100644 index 0000000000..ccf02a61df --- /dev/null +++ b/changelog.d/6971.feature @@ -0,0 +1 @@ +Validate the alt_aliases property of canonical alias events. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0c20601600..616942b057 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -66,6 +66,7 @@ class Codes(object): EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + BAD_ALIAS = "M_BAD_ALIAS" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 0b23ca919a..61eb49059b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import collections import logging import string from typing import List @@ -307,15 +305,17 @@ class DirectoryHandler(BaseHandler): send_update = True content.pop("alias", "") - # Filter alt_aliases for the removed alias. - alt_aliases = content.pop("alt_aliases", None) - # If the aliases are not a list (or not found) do not attempt to modify - # the list. - if isinstance(alt_aliases, collections.Sequence): + # Filter the alt_aliases property for the removed alias. Note that the + # value is not modified if alt_aliases is of an unexpected form. + alt_aliases = content.get("alt_aliases") + if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases: send_update = True alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: content["alt_aliases"] = alt_aliases + else: + del content["alt_aliases"] if send_update: yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a0103addd3..0c84c6cec4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -888,19 +888,60 @@ class EventCreationHandler(object): yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) + # Validate a newly added alias or newly added alt_aliases. + + original_alias = None + original_alt_aliases = set() + + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_event = yield self.store.get_event(original_event_id) + + if original_event: + original_alias = original_event.content.get("alias", None) + original_alt_aliases = original_event.content.get("alt_aliases", []) + + # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - if room_alias_str: + directory_handler = self.hs.get_handlers().directory_handler + if room_alias_str and room_alias_str != original_alias: room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) if mapping["room_id"] != event.room_id: raise SynapseError( 400, "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM + ) + + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + room_alias = RoomAlias.from_string(alias_str) + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" + % (room_alias_str,), + Codes.BAD_ALIAS, + ) + federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: diff --git a/synapse/types.py b/synapse/types.py index f3cd465735..acf60baddc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -23,7 +23,7 @@ import attr from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): @@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( - 400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL) + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) parts = s[1:].split(":", 1) @@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 400, "Expected %s of the form '%slocalname:domain'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) domain = parts[1] @@ -235,11 +238,13 @@ class GroupID(DomainSpecificString): def from_string(cls, s): group_id = super(GroupID, cls).from_string(s) if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty") + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) if contains_invalid_mxid_characters(group_id.localpart): raise SynapseError( - 400, "Group ID can only contain characters a-z, 0-9, or '=_-./'" + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + Codes.INVALID_PARAM, ) return group_id diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 27b916aed4..3397cfa485 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -88,6 +88,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias_not_allowed(self): + """Removing an alias should be denied if a user does not have the proper permissions.""" room_id = "!8765qwer:test" self.get_success( self.store.create_room_alias_association(self.my_room, room_id, ["test"]) @@ -101,6 +102,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias(self): + """Removing an alias should work when a user does has the proper permissions.""" room_id = "!8765qwer:test" user_id = "@user:test" self.get_success( @@ -159,30 +161,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) self.test_alias = "#test:test" - self.room_alias = RoomAlias.from_string(self.test_alias) + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( - self.room_alias, self.room_id, ["test"], self.admin_user + room_alias, self.room_id, ["test"], self.admin_user ) ) + return room_alias - def test_remove_alias(self): - """Removing an alias that is the canonical alias should remove it there too.""" - # Set this new alias as the canonical alias for this room + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, - "m.room.canonical_alias", - {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, - tok=self.admin_user_tok, + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, ) - data = self.get_success( + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( self.state_handler.get_current_state( self.room_id, EventTypes.CanonicalAlias, "" ) ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) @@ -193,11 +207,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) @@ -205,29 +215,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" - other_room_alias = RoomAlias.from_string(other_test_alias) - self.get_success( - self.store.create_room_alias_association( - other_room_alias, self.room_id, ["test"], self.admin_user - ) - ) + other_room_alias = self._add_alias(other_test_alias) # Set the alias as the canonical alias for this room. - self.helper.send_state( - self.room_id, - "m.room.canonical_alias", + self._set_canonical_alias( { "alias": self.test_alias, "alt_aliases": [self.test_alias, other_test_alias], - }, - tok=self.admin_user_tok, + } ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual( data["content"]["alt_aliases"], [self.test_alias, other_test_alias] @@ -240,11 +238,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2f3df5f88f..7dd86d0c27 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/test_types.py b/tests/test_types.py index 8d97c751ea..480bea1bdc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase): self.fail("Parsing '%s' should raise exception" % id_string) except SynapseError as exc: self.assertEqual(400, exc.code) - self.assertEqual("M_UNKNOWN", exc.errcode) + self.assertEqual("M_INVALID_PARAM", exc.errcode) class MapUsernameTestCase(unittest.TestCase): -- cgit 1.5.1 From 8ef8fb2c1c7c4aeb80fce4deea477b37754ce539 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 4 Mar 2020 13:11:04 +0000 Subject: Read the room version from database when fetching events (#6874) This is a precursor to giving EventBase objects the knowledge of which room version they belong to. --- changelog.d/6874.misc | 1 + synapse/storage/data_stores/main/events_worker.py | 84 ++++++++++++++++++----- tests/replication/slave/storage/test_events.py | 10 +++ 3 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 changelog.d/6874.misc (limited to 'tests') diff --git a/changelog.d/6874.misc b/changelog.d/6874.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6874.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 47a3a26072..ca237c6f12 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -28,9 +28,12 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError -from synapse.api.room_versions import EventFormatVersions -from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process @@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore): # of a event format version, so it must be a V1 event. format_version = EventFormatVersions.V1 - original_ev = event_type_from_format_version(format_version)( + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.error( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( event_dict=d, + room_version=room_version, internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) @@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore): of EventFormatVersions. 'None' means the event predates EventFormatVersions (so the event is format V1). + * room_version_id (str|None): The version of the room which contains the event. + Hopefully one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + * rejected_reason (str|None): if the event was rejected, the reason why. @@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore): """ event_dict = {} for evs in batch_iter(event_ids, 200): - sql = ( - "SELECT " - " e.event_id, " - " e.internal_metadata," - " e.json," - " e.format_version, " - " rej.reason " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " WHERE " - ) + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ clause, args = make_in_list_sql_clause( txn.database_engine, "e.event_id", evs @@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore): "internal_metadata": row[1], "json": row[2], "format_version": row[3], - "rejected_reason": row[4], + "room_version_id": row[4], + "rejected_reason": row[5], "redactions": [], } diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d31210fbe4..f0561b30e3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,6 +15,7 @@ import logging from canonicaljson import encode_canonical_json +from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource @@ -58,6 +59,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super(SlavedEventStoreTestCase, self).setUp() + def prepare(self, *args, **kwargs): + super().prepare(*args, **kwargs) + + self.get_success( + self.master_store.store_room( + ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1, + ) + ) + def tearDown(self): [unpatch() for unpatch in self.unpatches] -- cgit 1.5.1 From 13892776ef7e0b1af2f82c9ca53f7bbd1c60d66f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Mar 2020 11:30:46 -0500 Subject: Allow deleting an alias if the user has sufficient power level (#6986) --- changelog.d/6986.feature | 1 + synapse/api/auth.py | 9 +-- synapse/handlers/directory.py | 107 ++++++++++++++++++++++---------- tests/handlers/test_directory.py | 128 +++++++++++++++++++++++++++++++-------- tox.ini | 1 + 5 files changed, 182 insertions(+), 64 deletions(-) create mode 100644 changelog.d/6986.feature (limited to 'tests') diff --git a/changelog.d/6986.feature b/changelog.d/6986.feature new file mode 100644 index 0000000000..16dea8bd7f --- /dev/null +++ b/changelog.d/6986.feature @@ -0,0 +1 @@ +Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5ca18b4301..c1ade1333b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -539,7 +539,7 @@ class Auth(object): @defer.inlineCallbacks def check_can_change_room_list(self, room_id: str, user: UserID): - """Check if the user is allowed to edit the room's entry in the + """Determine whether the user is allowed to edit the room's entry in the published room list. Args: @@ -570,12 +570,7 @@ class Auth(object): ) user_level = event_auth.get_user_power_level(user_id, auth_events) - if user_level < send_level: - raise AuthError( - 403, - "This server requires you to be a moderator in the room to" - " edit its room list entry", - ) + return user_level >= send_level @staticmethod def has_access_token(request): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 61eb49059b..1d842c369b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -15,7 +15,7 @@ import logging import string -from typing import List +from typing import Iterable, List, Optional from twisted.internet import defer @@ -28,6 +28,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.appservice import ApplicationService from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -55,7 +56,13 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None, creator=None): + def _create_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Optional[Iterable[str]] = None, + creator: Optional[str] = None, + ): # general association creation for both human users and app services for wchar in string.whitespace: @@ -81,17 +88,21 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association( - self, requester, room_alias, room_id, servers=None, check_membership=True, + self, + requester: Requester, + room_alias: RoomAlias, + room_id: str, + servers: Optional[List[str]] = None, + check_membership: bool = True, ): """Attempt to create a new alias Args: - requester (Requester) - room_alias (RoomAlias) - room_id (str) - servers (list[str]|None): List of servers that others servers - should try and join via - check_membership (bool): Whether to check if the user is in the room + requester + room_alias + room_id + servers: Iterable of servers that others servers should try and join via + check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). Returns: @@ -145,15 +156,15 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers, creator=user_id) @defer.inlineCallbacks - def delete_association(self, requester, room_alias): + def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call delete_appservice_association) Args: - requester (Requester): - room_alias (RoomAlias): + requester + room_alias Returns: Deferred[unicode]: room id that the alias used to point to @@ -189,16 +200,16 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - yield self._update_canonical_alias( - requester, requester.user.to_string(), room_id, room_alias - ) + yield self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) return room_id @defer.inlineCallbacks - def delete_appservice_association(self, service, room_alias): + def delete_appservice_association( + self, service: ApplicationService, room_alias: RoomAlias + ): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -208,7 +219,7 @@ class DirectoryHandler(BaseHandler): yield self._delete_association(room_alias) @defer.inlineCallbacks - def _delete_association(self, room_alias): + def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -217,7 +228,7 @@ class DirectoryHandler(BaseHandler): return room_id @defer.inlineCallbacks - def get_association(self, room_alias): + def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): result = yield self.get_association_from_room_alias(room_alias) @@ -282,7 +293,9 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + def _update_canonical_alias( + self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias + ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. @@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): + def get_association_from_room_alias(self, room_alias: RoomAlias): result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias, user_id=None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have @@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler): return defer.succeed(True) @defer.inlineCallbacks - def _user_can_delete_alias(self, alias, user_id): + def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + """Determine whether a user can delete an alias. + + One of the following must be true: + + 1. The user created the alias. + 2. The user is a server administrator. + 3. The user has a power-level sufficient to send a canonical alias event + for the current room. + + """ creator = yield self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) - return is_admin + # Resolve the alias to the corresponding room. + room_mapping = yield self.get_association(alias) + room_id = room_mapping["room_id"] + if not room_id: + return False + + res = yield self.auth.check_can_change_room_list( + room_id, UserID.from_string(user_id) + ) + return res @defer.inlineCallbacks - def edit_published_room_list(self, requester, room_id, visibility): + def edit_published_room_list( + self, requester: Requester, room_id: str, visibility: str + ): """Edit the entry of the room in the published room list. requester - room_id (str) - visibility (str): "public" or "private" + room_id + visibility: "public" or "private" """ user_id = requester.user.to_string() @@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler): if room is None: raise SynapseError(400, "Unknown room") - yield self.auth.check_can_change_room_list(room_id, requester.user) + can_change_room_list = yield self.auth.check_can_change_room_list( + room_id, requester.user + ) + if not can_change_room_list: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry", + ) making_public = visibility == "public" if making_public: @@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def edit_published_appservice_room_list( - self, appservice_id, network_id, room_id, visibility + self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public room list. Args: - appservice_id (str): ID of the appservice that owns the list - network_id (str): The ID of the network the list is associated with - room_id (str) - visibility (str): either "public" or "private" + appservice_id: ID of the appservice that owns the list + network_id: The ID of the network the list is associated with + room_id + visibility: either "public" or "private" """ if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 3397cfa485..5e40adba52 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,6 +18,7 @@ from mock import Mock from twisted.internet import defer +import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig @@ -87,52 +88,131 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_delete_alias_not_allowed(self): - """Removing an alias should be denied if a user does not have the proper permissions.""" - room_id = "!8765qwer:test" + def test_incoming_fed_query(self): + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) + ) + + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) + ) + + self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestDeleteAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def _create_alias(self, user): + # Create a new alias to this room. self.get_success( - self.store.create_room_alias_association(self.my_room, room_id, ["test"]) + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], user + ) ) + def test_delete_alias_not_allowed(self): + """A user that doesn't meet the expected guidelines cannot delete an alias.""" + self._create_alias(self.admin_user) self.get_failure( self.handler.delete_association( - create_requester("@user:test"), self.my_room + create_requester(self.test_user), self.room_alias ), synapse.api.errors.AuthError, ) - def test_delete_alias(self): - """Removing an alias should work when a user does has the proper permissions.""" - room_id = "!8765qwer:test" - user_id = "@user:test" - self.get_success( - self.store.create_room_alias_association( - self.my_room, room_id, ["test"], user_id + def test_delete_alias_creator(self): + """An alias creator can delete their own alias.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias ) ) + self.assertEquals(self.room_id, result) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_admin(self): + """A server admin can delete an alias created by another user.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias as the admin. result = self.get_success( - self.handler.delete_association(create_requester(user_id), self.my_room) + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) ) - self.assertEquals(room_id, result) + self.assertEquals(self.room_id, result) - # The alias should not be found. + # Confirm the alias is gone. self.get_failure( - self.handler.get_association(self.my_room), synapse.api.errors.SynapseError + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, ) - def test_incoming_fed_query(self): - self.get_success( - self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] - ) + def test_delete_alias_sufficient_power(self): + """A user with a sufficient power level should be able to delete an alias.""" + self._create_alias(self.admin_user) + + # Increase the user's power level. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"users": {self.test_user: 100}}, + tok=self.admin_user_tok, ) - response = self.get_success( - self.handler.on_directory_query({"room_alias": "#your-room:test"}) + # They can now delete the alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) ) + self.assertEquals(self.room_id, result) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) class CanonicalAliasTestCase(unittest.HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 097ebb8774..7622aa19f1 100644 --- a/tox.ini +++ b/tox.ini @@ -185,6 +185,7 @@ commands = mypy \ synapse/federation/federation_client.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/directory.py \ synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.5.1 From 1d66dce83e58827aae12080552edeaeb357b1997 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 6 Mar 2020 18:14:19 +0000 Subject: Break down monthly active users by appservice_id (#7030) * Break down monthly active users by appservice_id and emit via prometheus. Co-authored-by: Brendan Abolivier --- changelog.d/7030.feature | 1 + synapse/app/homeserver.py | 13 +++++++ .../data_stores/main/monthly_active_users.py | 32 ++++++++++++++++- tests/storage/test_monthly_active_users.py | 42 ++++++++++++++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7030.feature (limited to 'tests') diff --git a/changelog.d/7030.feature b/changelog.d/7030.feature new file mode 100644 index 0000000000..fcfdb8d8a1 --- /dev/null +++ b/changelog.d/7030.feature @@ -0,0 +1 @@ +Break down monthly active users by `appservice_id` and emit via Prometheus. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c2a334a2b0..e0fdddfdc9 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -298,6 +298,11 @@ class SynapseHomeServer(HomeServer): # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +current_mau_by_service_gauge = Gauge( + "synapse_admin_mau_current_mau_by_service", + "Current MAU by service", + ["app_service"], +) max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau:registered_reserved_users", @@ -585,12 +590,20 @@ def run(hs): @defer.inlineCallbacks def generate_monthly_active_users(): current_mau_count = 0 + current_mau_count_by_service = {} reserved_users = () store = hs.get_datastore() if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: current_mau_count = yield store.get_monthly_active_count() + current_mau_count_by_service = ( + yield store.get_monthly_active_count_by_service() + ) reserved_users = yield store.get_registered_reserved_users() current_mau_gauge.set(float(current_mau_count)) + + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels(app_service).set(float(count)) + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) max_mau_gauge.set(float(hs.config.max_mau_value)) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 1507a14e09..925bc5691b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users(txn): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - txn.execute(sql) (count,) = txn.fetchone() return count return self.db.runInteraction("count_users", _count_users) + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. + Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. + + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db.runInteraction("count_users_by_service", _count_users_by_service) + @defer.inlineCallbacks def get_registered_reserved_users(self): """Of the reserved threepids defined in config, which are associated @@ -291,6 +318,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () + ) self._invalidate_cache_and_stream( txn, self.user_last_seen_monthly_active, (user_id,) ) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 3c78faab45..bc53bf0951 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -303,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_monthly_active_count_by_service(self): + appservice1_user1 = "@appservice1_user1:example.com" + appservice1_user2 = "@appservice1_user2:example.com" + + appservice2_user1 = "@appservice2_user1:example.com" + native_user1 = "@native_user1:example.com" + + service1 = "service1" + service2 = "service2" + native = "native" + + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + self.store.register_user(user_id=native_user1, password_hash=None) + self.pump() + + count = self.store.get_monthly_active_count_by_service() + self.assertEqual({}, self.get_success(count)) + + self.store.upsert_monthly_active_user(native_user1) + self.store.upsert_monthly_active_user(appservice1_user1) + self.store.upsert_monthly_active_user(appservice1_user2) + self.store.upsert_monthly_active_user(appservice2_user1) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEqual(4, self.get_success(count)) + + count = self.store.get_monthly_active_count_by_service() + result = self.get_success(count) + + self.assertEqual(2, result[service1]) + self.assertEqual(1, result[service2]) + self.assertEqual(1, result[native]) -- cgit 1.5.1 From 1f5f3ae8b1c5db96d36ac7c104f13553bc4283da Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sun, 8 Mar 2020 14:49:33 +0100 Subject: Add options to disable setting profile info for prevent changes. --- synapse/config/registration.py | 11 +++++++++++ synapse/handlers/profile.py | 10 ++++++++++ tests/handlers/test_profile.py | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9bb3beedbc..d9f452dcea 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,6 +129,9 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.disable_set_displayname = config.get("disable_set_displayname", False) + self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -330,6 +333,14 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # If enabled, don't let users set their own display names/avatars + # other than for the very first time (unless they are a server admin). + # Useful when provisioning users based on the contents of a 3rd party + # directory and to avoid ambiguities. + # + # disable_set_displayname: False + # disable_set_avatar_url: False + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b..fb7e84f3b8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and self.hs.config.disable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError(400, "Changing displayname is disabled on this server") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +223,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and self.hs.config.disable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError(400, "Changing avatar url is disabled on this server") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..b85520c688 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.config = hs.config @defer.inlineCallbacks def test_get_my_name(self): @@ -90,6 +91,19 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.config.disable_set_displayname = True + + # Set first displayname is allowed, if displayname is null + self.store.set_profile_displayname(self.frank.localpart, "Frank") + + d = self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -147,3 +161,20 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.config.disable_set_avatar_url = True + + # Set first time avatar is allowed, if displayname is null + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + d = self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + + yield self.assertFailure(d, SynapseError) -- cgit 1.5.1 From 06eb5cae08272c401a586991fc81f788825f910b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 9 Mar 2020 08:58:25 -0400 Subject: Remove special auth and redaction rules for aliases events in experimental room ver. (#7037) --- changelog.d/7037.feature | 1 + synapse/api/room_versions.py | 9 +-- synapse/crypto/event_signing.py | 2 +- synapse/event_auth.py | 8 +-- synapse/events/utils.py | 12 ++-- synapse/storage/data_stores/main/events.py | 10 +++- tests/events/test_utils.py | 35 ++++++++++- tests/test_event_auth.py | 93 +++++++++++++++++++++++++++++- 8 files changed, 148 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7037.feature (limited to 'tests') diff --git a/changelog.d/7037.feature b/changelog.d/7037.feature new file mode 100644 index 0000000000..4bc1b3b19f --- /dev/null +++ b/changelog.d/7037.feature @@ -0,0 +1 @@ +Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index cf7ee60d3a..871179749a 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -57,7 +57,7 @@ class RoomVersion(object): state_res = attr.ib() # int; one of the StateResolutionVersions enforce_key_validity = attr.ib() # bool - # bool: before MSC2260, anyone was allowed to send an aliases event + # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool, default=False) @@ -102,12 +102,13 @@ class RoomVersions(object): enforce_key_validity=True, special_case_aliases_auth=True, ) - MSC2260_DEV = RoomVersion( - "org.matrix.msc2260", + MSC2432_DEV = RoomVersion( + "org.matrix.msc2432", RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, + special_case_aliases_auth=False, ) @@ -119,6 +120,6 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V3, RoomVersions.V4, RoomVersions.V5, - RoomVersions.MSC2260_DEV, + RoomVersions.MSC2432_DEV, ) } # type: Dict[str, RoomVersion] diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 5f733c1cf5..0422c43fab 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -140,7 +140,7 @@ def compute_event_signature( Returns: a dictionary in the same format of an event's signatures field. """ - redact_json = prune_event_dict(event_dict) + redact_json = prune_event_dict(room_version, event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) if logger.isEnabledFor(logging.DEBUG): diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 472f165044..46beb5334f 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -137,7 +137,7 @@ def check( raise AuthError(403, "This room has been marked as unfederatable.") # 4. If type is m.room.aliases - if event.type == EventTypes.Aliases: + if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth: # 4a. If event has no state_key, reject if not event.is_state(): raise AuthError(403, "Alias event must be a state event") @@ -152,10 +152,8 @@ def check( ) # 4c. Otherwise, allow. - # This is removed by https://github.com/matrix-org/matrix-doc/pull/2260 - if room_version_obj.special_case_aliases_auth: - logger.debug("Allowing! %s", event) - return + logger.debug("Allowing! %s", event) + return if logger.isEnabledFor(logging.DEBUG): logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index bc6f98ae3b..b75b097e5e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -23,6 +23,7 @@ from frozendict import frozendict from twisted.internet import defer from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -43,7 +44,7 @@ def prune_event(event: EventBase) -> EventBase: the user has specified, but we do want to keep necessary information like type, state_key etc. """ - pruned_event_dict = prune_event_dict(event.get_dict()) + pruned_event_dict = prune_event_dict(event.room_version, event.get_dict()) from . import make_event_from_dict @@ -57,15 +58,12 @@ def prune_event(event: EventBase) -> EventBase: return pruned_event -def prune_event_dict(event_dict): +def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: """Redacts the event_dict in the same way as `prune_event`, except it operates on dicts rather than event objects - Args: - event_dict (dict) - Returns: - dict: A copy of the pruned event dict + A copy of the pruned event dict """ allowed_keys = [ @@ -112,7 +110,7 @@ def prune_event_dict(event_dict): "kick", "redact", ) - elif event_type == EventTypes.Aliases: + elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 8ae23df00a..d593ef47b8 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1168,7 +1168,11 @@ class EventsStore( and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json(prune_event_dict(original_event.get_dict())) + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) else: # Redaction wasn't allowed pruned_json = None @@ -1929,7 +1933,9 @@ class EventsStore( return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json(prune_event_dict(event.get_dict())) + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) # Update the event_json table to replace the event's JSON with the pruned # JSON. diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 45d55b9e94..ab5f5ac549 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( copy_power_levels_contents, @@ -36,9 +37,9 @@ class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. """ - def run_test(self, evdict, matchdict): + def run_test(self, evdict, matchdict, **kwargs): self.assertEquals( - prune_event(make_event_from_dict(evdict)).get_dict(), matchdict + prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) def test_minimal(self): @@ -128,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase): }, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + self.run_test( + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + }, + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + "signatures": {}, + "unsigned": {}, + }, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + self.run_test( + {"type": "m.room.aliases", "content": {"aliases": ["test"]}}, + { + "type": "m.room.aliases", + "content": {}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC2432_DEV, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index bfa5d6f510..6c2351cf55 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -19,6 +19,7 @@ from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.types import get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -51,7 +52,7 @@ class EventAuthTestCase(unittest.TestCase): _random_state_event(joiner), auth_events, do_sig_check=False, - ), + ) def test_state_default_level(self): """ @@ -87,6 +88,83 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False, + ) + + # Reject an event with no state key. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + + # If the domain of the sender does not match the state key, reject. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Note that the member does *not* need to be in the room. + event_auth.check( + RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator), + auth_events, + do_sig_check=False, + ) + + # No particular checks are done on the state key. + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Per standard auth rules, the member must be in the room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(other), + auth_events, + do_sig_check=False, + ) + # helpers for making events @@ -131,6 +209,19 @@ def _power_levels_event(sender, content): ) +def _alias_event(sender, **kwargs): + data = { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.aliases", + "sender": sender, + "state_key": get_domain_from_id(sender), + "content": {"aliases": []}, + } + data.update(**kwargs) + return make_event_from_dict(data) + + def _random_state_event(sender): return make_event_from_dict( { -- cgit 1.5.1 From 04f4b5f6f87fbba0b2f1a4f011c496de3021c81a Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 19:51:31 +0100 Subject: add tests --- tests/handlers/test_profile.py | 6 +- tests/rest/client/v2_alpha/test_account.py | 308 +++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index b85520c688..98b508c3d4 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -70,7 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() - self.config = hs.config + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -93,7 +93,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_if_disabled(self): - self.config.disable_set_displayname = True + self.hs.config.disable_set_displayname = True # Set first displayname is allowed, if displayname is null self.store.set_profile_displayname(self.frank.localpart, "Frank") @@ -164,7 +164,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): - self.config.disable_set_avatar_url = True + self.hs.config.disable_set_avatar_url = True # Set first time avatar is allowed, if displayname is null self.store.set_profile_avatar_url( diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb..ac9f200de3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -325,3 +325,311 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + return + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test add mail to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test add mail to profile if disabled + """ + self.hs.config.disable_3pid_changes = True + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test delete mail from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + { + "medium": "email", + "address": self.email + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test delete mail from profile if disabled + """ + self.hs.config.disable_3pid_changes = True + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + { + "medium": "email", + "address": self.email + }, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) -- cgit 1.5.1 From 50ea178c201588b5e6b3f93e1af56aef0b4e8368 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 19:57:04 +0100 Subject: lint --- tests/rest/client/v2_alpha/test_account.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index ac9f200de3..e178a53335 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -438,7 +438,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + self.assertEqual( + "3PID changes disabled on this server", channel.json_body["error"] + ) # Get user request, channel = self.make_request( @@ -466,10 +468,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"account/3pid/delete", - { - "medium": "email", - "address": self.email - }, + {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) self.render(request) @@ -503,16 +502,15 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"account/3pid/delete", - { - "medium": "email", - "address": self.email - }, + {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + self.assertEqual( + "3PID changes disabled on this server", channel.json_body["error"] + ) # Get user request, channel = self.make_request( -- cgit 1.5.1 From 7e5f40e7716813f0d32e2efcb32df3c263fbfc63 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 21:00:36 +0100 Subject: fix tests --- tests/handlers/test_profile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 98b508c3d4..f8c0da5ced 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -96,7 +96,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.disable_set_displayname = True # Set first displayname is allowed, if displayname is null - self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") d = self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -167,7 +167,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.disable_set_avatar_url = True # Set first time avatar is allowed, if displayname is null - self.store.set_profile_avatar_url( + yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) -- cgit 1.5.1 From 885134529ffd95dd118d3228e69f0e3553f5a6a7 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 22:09:29 +0100 Subject: updates after review --- changelog.d/7053.feature | 2 +- docs/sample_config.yaml | 10 +++++----- synapse/config/registration.py | 16 ++++++++-------- synapse/handlers/profile.py | 8 ++++---- synapse/rest/client/v2_alpha/account.py | 18 ++++++++++++------ tests/handlers/test_profile.py | 6 +++--- tests/rest/client/v2_alpha/test_account.py | 17 +++++++---------- 7 files changed, 40 insertions(+), 37 deletions(-) (limited to 'tests') diff --git a/changelog.d/7053.feature b/changelog.d/7053.feature index 79955b9780..00f47b2a14 100644 --- a/changelog.d/7053.feature +++ b/changelog.d/7053.feature @@ -1 +1 @@ -Add options to disable setting profile info for prevent changes. \ No newline at end of file +Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d3ecffac7d..8333800a10 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1057,18 +1057,18 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process -# If enabled, don't let users set their own display names/avatars +# If disabled, don't let users set their own display names/avatars # other than for the very first time (unless they are a server admin). # Useful when provisioning users based on the contents of a 3rd party # directory and to avoid ambiguities. # -#disable_set_displayname: false -#disable_set_avatar_url: false +#enable_set_displayname: true +#enable_set_avatar_url: true -# If true, stop users from trying to change the 3PIDs associated with +# If false, stop users from trying to change the 3PIDs associated with # their accounts. # -#disable_3pid_changes: false +#enable_3pid_changes: true # Users who register on this homeserver will automatically be joined # to these rooms diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1abc0a79af..d4897ec9b6 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,9 +129,9 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.disable_set_displayname = config.get("disable_set_displayname", False) - self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) - self.disable_3pid_changes = config.get("disable_3pid_changes", False) + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False @@ -334,18 +334,18 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process - # If enabled, don't let users set their own display names/avatars + # If disabled, don't let users set their own display names/avatars # other than for the very first time (unless they are a server admin). # Useful when provisioning users based on the contents of a 3rd party # directory and to avoid ambiguities. # - #disable_set_displayname: false - #disable_set_avatar_url: false + #enable_set_displayname: true + #enable_set_avatar_url: true - # If true, stop users from trying to change the 3PIDs associated with + # If false, stop users from trying to change the 3PIDs associated with # their accounts. # - #disable_3pid_changes: false + #enable_3pid_changes: true # Users who register on this homeserver will automatically be joined # to these rooms diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b049dd8e26..eb85dba015 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,11 +157,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and self.hs.config.disable_set_displayname: + if not by_admin and not self.hs.config.enable_set_displayname: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( - 400, "Changing displayname is disabled on this server" + 400, "Changing display name is disabled on this server", Codes.FORBIDDEN ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: @@ -225,11 +225,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and self.hs.config.disable_set_avatar_url: + if not by_admin and not self.hs.config.enable_set_avatar_url: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( - 400, "Changing avatar url is disabled on this server" + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN ) if len(new_avatar_url) > MAX_AVATAR_URL_LEN: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 97bddf36d9..e40136f2f3 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -599,8 +599,10 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -646,8 +648,10 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -749,8 +753,10 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f8c0da5ced..e600b9777b 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -93,7 +93,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_if_disabled(self): - self.hs.config.disable_set_displayname = True + self.hs.config.enable_set_displayname = False # Set first displayname is allowed, if displayname is null yield self.store.set_profile_displayname(self.frank.localpart, "Frank") @@ -164,9 +164,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): - self.hs.config.disable_set_avatar_url = True + self.hs.config.enable_set_avatar_url = False - # Set first time avatar is allowed, if displayname is null + # Set first time avatar is allowed, if avatar is null yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index e178a53335..34e40a36d0 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -412,7 +413,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test add mail to profile if disabled """ - self.hs.config.disable_3pid_changes = True + self.hs.config.enable_3pid_changes = True client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -438,9 +439,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "3PID changes disabled on this server", channel.json_body["error"] - ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -486,7 +485,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test delete mail from profile if disabled """ - self.hs.config.disable_3pid_changes = True + self.hs.config.enable_3pid_changes = True # Add a threepid self.get_success( @@ -508,9 +507,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "3PID changes disabled on this server", channel.json_body["error"] - ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -547,7 +544,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -582,7 +579,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user request, channel = self.make_request( -- cgit 1.5.1 From 39f6595b4ab108cb451072ae251a91117002191c Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 22:13:20 +0100 Subject: lint, fix tests --- synapse/handlers/profile.py | 4 +++- tests/rest/client/v2_alpha/test_account.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index eb85dba015..6aa1c0f5e0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -161,7 +161,9 @@ class BaseProfileHandler(BaseHandler): profile = yield self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( - 400, "Changing display name is disabled on this server", Codes.FORBIDDEN + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 34e40a36d0..99cc9163f3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -413,7 +413,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test add mail to profile if disabled """ - self.hs.config.enable_3pid_changes = True + self.hs.config.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -485,7 +485,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test delete mail from profile if disabled """ - self.hs.config.enable_3pid_changes = True + self.hs.config.enable_3pid_changes = False # Add a threepid self.get_success( -- cgit 1.5.1 From 6a35046363a6f5d41199256c80eef4ea7e385986 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 Mar 2020 11:25:01 +0000 Subject: Revert "Add options to disable setting profile info for prevent changes. (#7053)" This reverts commit 54dd28621b070ca67de9f773fe9a89e1f4dc19da, reversing changes made to 6640460d054e8f4444046a34bdf638921b31c01e. --- changelog.d/7053.feature | 1 - docs/sample_config.yaml | 13 -- synapse/config/registration.py | 17 -- synapse/handlers/profile.py | 16 -- synapse/rest/client/v2_alpha/account.py | 16 -- tests/handlers/test_profile.py | 33 +--- tests/rest/client/v2_alpha/test_account.py | 303 ----------------------------- 7 files changed, 1 insertion(+), 398 deletions(-) delete mode 100644 changelog.d/7053.feature (limited to 'tests') diff --git a/changelog.d/7053.feature b/changelog.d/7053.feature deleted file mode 100644 index 00f47b2a14..0000000000 --- a/changelog.d/7053.feature +++ /dev/null @@ -1 +0,0 @@ -Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 91eff4c8ad..2ff0dd05a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1057,19 +1057,6 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process -# If disabled, don't let users set their own display names/avatars -# (unless they are a server admin) other than for the very first time. -# Useful when provisioning users based on the contents of a 3rd party -# directory and to avoid ambiguities. -# -#enable_set_displayname: true -#enable_set_avatar_url: true - -# If false, stop users from trying to change the 3PIDs associated with -# their accounts. -# -#enable_3pid_changes: true - # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ee737eb40d..9bb3beedbc 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,10 +129,6 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.enable_set_displayname = config.get("enable_set_displayname", True) - self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) - self.enable_3pid_changes = config.get("enable_3pid_changes", True) - self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -334,19 +330,6 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process - # If disabled, don't let users set their own display names/avatars - # (unless they are a server admin) other than for the very first time. - # Useful when provisioning users based on the contents of a 3rd party - # directory and to avoid ambiguities. - # - #enable_set_displayname: true - #enable_set_avatar_url: true - - # If false, stop users from trying to change the 3PIDs associated with - # their accounts. - # - #enable_3pid_changes: true - # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..50ce0c585b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,15 +157,6 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) - if profile.display_name: - raise SynapseError( - 400, - "Changing display name is disabled on this server", - Codes.FORBIDDEN, - ) - if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -227,13 +218,6 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) - if profile.avatar_url: - raise SynapseError( - 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN - ) - if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e40136f2f3..dc837d6c75 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -599,11 +599,6 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -648,11 +643,6 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -748,16 +738,10 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index e600b9777b..d60c124eec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,7 +70,6 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() - self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -91,19 +90,6 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) - @defer.inlineCallbacks - def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False - - # Set first displayname is allowed, if displayname is null - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") - - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." - ) - - yield self.assertFailure(d, SynapseError) - @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -161,20 +147,3 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) - - @defer.inlineCallbacks - def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False - - # Set first time avatar is allowed, if avatar is null - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" - ) - - d = self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", - ) - - yield self.assertFailure(d, SynapseError) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 99cc9163f3..c3facc00eb 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,6 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -326,305 +325,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) - - -class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - self.email_attempts = [] - - def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) - return self.hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.user_id = self.register_user("kermit", "test") - self.user_id_tok = self.login("kermit", "test") - self.email = "test@example.com" - self.url_3pid = b"account/3pid" - - def test_add_email(self): - """Test add mail to profile - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_add_email_if_disabled(self): - """Test add mail to profile if disabled - """ - self.hs.config.enable_3pid_changes = False - - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email(self): - """Test delete mail from profile - """ - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - request, channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email_if_disabled(self): - """Test delete mail from profile if disabled - """ - self.hs.config.enable_3pid_changes = False - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - request, channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to add email without clicking the link - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - client_secret = "foobar" - session_id = "weasle" - - # Attempt to add email without even requesting an email - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def _request_token(self, email, client_secret): - request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - return channel.json_body["sid"] - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) -- cgit 1.5.1 From 60724c46b7dc5300243fd97d5a485564b3e00afe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 17 Mar 2020 07:37:04 -0400 Subject: Remove special casing of `m.room.aliases` events (#7034) --- changelog.d/7034.removal | 1 + synapse/handlers/room.py | 16 +------------ synapse/rest/client/v1/room.py | 12 ---------- tests/rest/admin/test_admin.py | 7 ++++++ tests/rest/client/v1/test_directory.py | 41 +++++++++++++++++++++------------- 5 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 changelog.d/7034.removal (limited to 'tests') diff --git a/changelog.d/7034.removal b/changelog.d/7034.removal new file mode 100644 index 0000000000..be8d20e14f --- /dev/null +++ b/changelog.d/7034.removal @@ -0,0 +1 @@ +Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8ee870f0bb..f580ab2e9f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - new_pl_content = copy_power_levels_contents(old_room_pl_state.content) - - # pre-msc2260 rooms may not have the right setting for aliases. If no other - # value is set, set it now. - events_default = new_pl_content.get("events_default", 0) - new_pl_content.setdefault("events", {}).setdefault( - EventTypes.Aliases, events_default - ) - - logger.debug("Setting correct PLs in new room to %s", new_pl_content) yield self.event_creation_handler.create_and_send_nonmember_event( requester, { @@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": new_pl_content, + "content": old_room_pl_state.content, }, ratelimit=False, ) @@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler): EventTypes.RoomHistoryVisibility: 100, EventTypes.CanonicalAlias: 50, EventTypes.RoomAvatar: 50, - # MSC2260: Allow everybody to send alias events by default - # This will be reudundant on pre-MSC2260 rooms, since the - # aliases event is special-cased. - EventTypes.Aliases: 0, EventTypes.Tombstone: 100, EventTypes.ServerACL: 100, }, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 64f51406fb..bffd43de5f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -189,12 +189,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) - if event_type == EventTypes.Aliases: - # MSC2260 - raise SynapseError( - 400, "Cannot send m.room.aliases events via /rooms/{room_id}/state" - ) - event_dict = { "type": event_type, "content": content, @@ -242,12 +236,6 @@ class RoomSendEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - if event_type == EventTypes.Aliases: - # MSC2260 - raise SynapseError( - 400, "Cannot send m.room.aliases events via /rooms/{room_id}/send" - ) - event_dict = { "type": event_type, "content": content, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index e5984aaad8..0342aed416 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -868,6 +868,13 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) self.helper.send_state( room_id, "m.room.canonical_alias", diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 914cf54927..633b7dbda0 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_cannot_set_alias_via_state_event(self): - self.ensure_user_joined_room() - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(5)]} - request_data = json.dumps(data) - - request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.render(request) - self.assertEqual(channel.code, 400, channel.result) + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) def test_directory_endpoint_not_in_room(self): self.ensure_user_left_room() self.set_alias_via_directory(403) + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + def test_directory_in_room_too_long(self): self.ensure_user_joined_room() self.set_alias_via_directory(400, alias_length=256) + def test_state_event_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + def test_directory_in_room(self): self.ensure_user_joined_room() self.set_alias_via_directory(200) @@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200, channel.result) + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + def set_alias_via_directory(self, expected_code, alias_length=5): url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) data = {"room_id": self.room_id} -- cgit 1.5.1 From c37db0211e36cd298426ff8811e547b0acd10bf4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Mar 2020 22:32:25 +0100 Subject: Share SSL contexts for non-federation requests (#7094) Extends #5794 etc to the SimpleHttpClient so that it also applies to non-federation requests. Fixes #7092. --- changelog.d/7094.misc | 1 + synapse/crypto/context_factory.py | 68 ++++++++++++++-------- synapse/http/client.py | 3 - synapse/http/federation/matrix_federation_agent.py | 2 +- synapse/server.py | 6 +- tests/config/test_tls.py | 29 +++++---- .../federation/test_matrix_federation_agent.py | 6 +- 7 files changed, 71 insertions(+), 44 deletions(-) create mode 100644 changelog.d/7094.misc (limited to 'tests') diff --git a/changelog.d/7094.misc b/changelog.d/7094.misc new file mode 100644 index 0000000000..aa093ee3c0 --- /dev/null +++ b/changelog.d/7094.misc @@ -0,0 +1 @@ +Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections. diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index e93f0b3705..a5a2a7815d 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -75,7 +75,7 @@ class ServerContextFactory(ContextFactory): @implementer(IPolicyForHTTPS) -class ClientTLSOptionsFactory(object): +class FederationPolicyForHTTPS(object): """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers for federation. @@ -103,15 +103,15 @@ class ClientTLSOptionsFactory(object): # let us do). minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version] - self._verify_ssl = CertificateOptions( + _verify_ssl = CertificateOptions( trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS ) - self._verify_ssl_context = self._verify_ssl.getContext() - self._verify_ssl_context.set_info_callback(self._context_info_cb) + self._verify_ssl_context = _verify_ssl.getContext() + self._verify_ssl_context.set_info_callback(_context_info_cb) - self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) - self._no_verify_ssl_context = self._no_verify_ssl.getContext() - self._no_verify_ssl_context.set_info_callback(self._context_info_cb) + _no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) + self._no_verify_ssl_context = _no_verify_ssl.getContext() + self._no_verify_ssl_context.set_info_callback(_context_info_cb) def get_options(self, host: bytes): @@ -136,23 +136,6 @@ class ClientTLSOptionsFactory(object): return SSLClientConnectionCreator(host, ssl_context, should_verify) - @staticmethod - def _context_info_cb(ssl_connection, where, ret): - """The 'information callback' for our openssl context object.""" - # we assume that the app_data on the connection object has been set to - # a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator) - tls_protocol = ssl_connection.get_app_data() - try: - # ... we further assume that SSLClientConnectionCreator has set the - # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. - tls_protocol._synapse_tls_verifier.verify_context_info_cb( - ssl_connection, where - ) - except: # noqa: E722, taken from the twisted implementation - logger.exception("Error during info_callback") - f = Failure() - tls_protocol.failVerification(f) - def creatorForNetloc(self, hostname, port): """Implements the IPolicyForHTTPS interace so that this can be passed directly to agents. @@ -160,6 +143,43 @@ class ClientTLSOptionsFactory(object): return self.get_options(hostname) +@implementer(IPolicyForHTTPS) +class RegularPolicyForHTTPS(object): + """Factory for Twisted SSLClientConnectionCreators that are used to make connections + to remote servers, for other than federation. + + Always uses the same OpenSSL context object, which uses the default OpenSSL CA + trust root. + """ + + def __init__(self): + trust_root = platformTrust() + self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext() + self._ssl_context.set_info_callback(_context_info_cb) + + def creatorForNetloc(self, hostname, port): + return SSLClientConnectionCreator(hostname, self._ssl_context, True) + + +def _context_info_cb(ssl_connection, where, ret): + """The 'information callback' for our openssl context objects. + + Note: Once this is set as the info callback on a Context object, the Context should + only be used with the SSLClientConnectionCreator. + """ + # we assume that the app_data on the connection object has been set to + # a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator) + tls_protocol = ssl_connection.get_app_data() + try: + # ... we further assume that SSLClientConnectionCreator has set the + # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. + tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where) + except: # noqa: E722, taken from the twisted implementation + logger.exception("Error during info_callback") + f = Failure() + tls_protocol.failVerification(f) + + @implementer(IOpenSSLClientConnectionCreator) class SSLClientConnectionCreator(object): """Creates openssl connection objects for client connections. diff --git a/synapse/http/client.py b/synapse/http/client.py index d4c285445e..3797545824 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -244,9 +244,6 @@ class SimpleHttpClient(object): pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) pool.cachedConnectionTimeout = 2 * 60 - # The default context factory in Twisted 14.0.0 (which we require) is - # BrowserLikePolicyForHTTPS which will do regular cert validation - # 'like a browser' self.agent = ProxyAgent( self.reactor, connectTimeout=15, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 647d26dc56..f5f917f5ae 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -45,7 +45,7 @@ class MatrixFederationAgent(object): Args: reactor (IReactor): twisted reactor to use for underlying requests - tls_client_options_factory (ClientTLSOptionsFactory|None): + tls_client_options_factory (FederationPolicyForHTTPS|None): factory to use for fetching client tls options, or none to disable TLS. _srv_resolver (SrvResolver|None): diff --git a/synapse/server.py b/synapse/server.py index fd2f69e928..1b980371de 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -26,7 +26,6 @@ import logging import os from twisted.mail.smtp import sendmail -from twisted.web.client import BrowserLikePolicyForHTTPS from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -35,6 +34,7 @@ from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory +from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker @@ -310,7 +310,7 @@ class HomeServer(object): return ( InsecureInterceptableContextFactory() if self.config.use_insecure_ssl_client_just_for_testing_do_not_use - else BrowserLikePolicyForHTTPS() + else RegularPolicyForHTTPS() ) def build_simple_http_client(self): @@ -420,7 +420,7 @@ class HomeServer(object): return PusherPool(self) def build_http_client(self): - tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + tls_client_options_factory = context_factory.FederationPolicyForHTTPS( self.config ) return MatrixFederationHttpClient(self, tls_client_options_factory) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 1be6ff563b..ec32d4b1ca 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -23,7 +23,7 @@ from OpenSSL import SSL from synapse.config._base import Config, RootConfig from synapse.config.tls import ConfigError, TlsConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from tests.unittest import TestCase @@ -180,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_tls_client_minimum_set_passed_through_1_0(self): """ @@ -195,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has not had any of the NO_TLS set. - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_acme_disabled_in_generated_config_no_acme_domain_provied(self): """ @@ -273,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) # Not in the whitelist opts = cf.get_options(b"notexample.com") @@ -282,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= # Caught by the wildcard opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) self.assertFalse(opts._verifier._verify_certs) + + +def _get_ssl_context_options(ssl_context: SSL.Context) -> int: + """get the options bits from an openssl context object""" + # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to + # use the low-level interface + return SSL._lib.SSL_CTX_get_options(ssl_context._context) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index cfcd98ff7d..fdc1d918ff 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -31,7 +31,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( @@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - self.tls_factory = ClientTLSOptionsFactory(config) + self.tls_factory = FederationPolicyForHTTPS(config) self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) @@ -715,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase): config = default_config("test", parse=True) # Build a new agent and WellKnownResolver with a different tls factory - tls_factory = ClientTLSOptionsFactory(config) + tls_factory = FederationPolicyForHTTPS(config) agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, -- cgit 1.5.1 From 4a17a647a9508b70de35130fd82e3e21474270a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Mar 2020 16:46:41 +0000 Subject: Improve get auth chain difference algorithm. (#7095) It was originally implemented by pulling the full auth chain of all state sets out of the database and doing set comparison. However, that can take a lot work if the state and auth chains are large. Instead, lets try and fetch the auth chains at the same time and calculate the difference on the fly, allowing us to bail early if all the auth chains converge. Assuming that the auth chains do converge more often than not, this should improve performance. Hopefully. --- changelog.d/7095.misc | 1 + synapse/state/__init__.py | 28 ++-- synapse/state/v2.py | 32 +---- .../storage/data_stores/main/event_federation.py | 150 +++++++++++++++++++- tests/state/test_v2.py | 13 +- tests/storage/test_event_federation.py | 157 ++++++++++++++++++--- 6 files changed, 310 insertions(+), 71 deletions(-) create mode 100644 changelog.d/7095.misc (limited to 'tests') diff --git a/changelog.d/7095.misc b/changelog.d/7095.misc new file mode 100644 index 0000000000..44fc9f616f --- /dev/null +++ b/changelog.d/7095.misc @@ -0,0 +1 @@ +Attempt to improve performance of state res v2 algorithm. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index df7a4f6a89..4afefc6b1d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -662,28 +662,16 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]): - """Gets the full auth chain for a set of events (including rejected - events). - - Includes the given event IDs in the result. - - Note that: - 1. All events must be state events. - 2. For v1 rooms this may not have the full auth chain in the - presence of rejected events - - Args: - event_ids: The event IDs of the events to fetch the auth chain for. - Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + Deferred[Set[str]]: Set of event IDs. """ - return self.store.get_auth_chain_ids( - event_ids, include_given=True, ignore_events=ignore_events, - ) + return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 0ffe6d8c14..18484e2fa6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Returns: Deferred[set[str]]: Set of event IDs """ - common = set(itervalues(state_sets[0])).intersection( - *(itervalues(s) for s in state_sets[1:]) - ) - - auth_sets = [] - for state_set in state_sets: - auth_ids = { - eid - for key, eid in iteritems(state_set) - if ( - key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) - or key - in ( - (EventTypes.PowerLevels, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), - ) - ) - and eid not in common - } - auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) - auth_ids.update(auth_chain) - - auth_sets.append(auth_ids) - - intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) - union = set().union(*auth_sets) + difference = yield state_res_store.get_auth_chain_difference( + [set(state_set.values()) for state_set in state_sets] + ) - return union - intersection + return difference def _seperate(state_sets): diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 49a7b8b433..62d4e9f599 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,7 +14,7 @@ # limitations under the License. import itertools import logging -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from six.moves.queue import Empty, PriorityQueue @@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediately show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. + # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # topological ordering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # We need to get the depth of the initial events for sorting purposes. + sql = """ + SELECT depth, event_id FROM events + WHERE %s + ORDER BY depth ASC + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", initial_events + ) + txn.execute(sql % (clause,), args) + + # The sorted list of events whose auth chains we should walk. + search = txn.fetchall() # type: List[Tuple[int, str]] + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. + if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) + + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. + sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + def get_oldest_events_in_room(self, room_id): return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5059ade850..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -603,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids, ignore_events): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -617,9 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -629,7 +626,7 @@ class TestStateResolutionStore(object): stack = list(event_ids) while stack: event_id = stack.pop() - if event_id in result or event_id in ignore_events: + if event_id in result: continue result.add(event_id) @@ -639,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a331517f4d..3aeec0dc0f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): room_id = "@ROOM:local" @@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i) + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) # this should get the last ten - r = yield self.store.get_prev_events_for_room(room_id) + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) for i in range(0, 10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) - @defer.inlineCallbacks def test_get_rooms_with_many_extremities(self): room1 = "#room1" room2 = "#room2" @@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i, room1) - yield self.store.db.runInteraction("insert", insert_event, i, room2) - yield self.store.db.runInteraction("insert", insert_event, i, room3) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) # Test simple case - r = yield self.store.get_rooms_with_many_extremities(5, 5, []) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) self.assertEqual(len(r), 3) # Does filter work? - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) self.assertTrue(room2 in r) self.assertTrue(room3 in r) self.assertEqual(len(r), 2) - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) self.assertEqual(r, [room3]) # Does filter and limit work? - r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id, stream_ordering): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + next_stream_ordering = 0 + for event_id in auth_graph: + next_stream_ordering += 1 + self.get_success( + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) -- cgit 1.5.1 From 51f358e2fe4b568009de46c0130bd6843ed8215b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Apr 2020 10:52:55 -0400 Subject: Do not treat display names as globs for push rules. (#7271) --- changelog.d/7271.bugfix | 1 + synapse/push/push_rule_evaluator.py | 69 +++++++++++++++++++--------------- tests/push/test_push_rule_evaluator.py | 65 ++++++++++++++++++++++++++++++++ tox.ini | 1 + 4 files changed, 106 insertions(+), 30 deletions(-) create mode 100644 changelog.d/7271.bugfix create mode 100644 tests/push/test_push_rule_evaluator.py (limited to 'tests') diff --git a/changelog.d/7271.bugfix b/changelog.d/7271.bugfix new file mode 100644 index 0000000000..e8315e4ce4 --- /dev/null +++ b/changelog.d/7271.bugfix @@ -0,0 +1 @@ +Do not treat display names as globs in push rules. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index b1587183a8..4cd702b5fa 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,9 +16,11 @@ import logging import re +from typing import Pattern from six import string_types +from synapse.events import EventBase from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -56,18 +58,18 @@ def _test_ineq_condition(condition, number): rhs = m.group(2) if not rhs.isdigit(): return False - rhs = int(rhs) + rhs_int = int(rhs) if ineq == "" or ineq == "==": - return number == rhs + return number == rhs_int elif ineq == "<": - return number < rhs + return number < rhs_int elif ineq == ">": - return number > rhs + return number > rhs_int elif ineq == ">=": - return number >= rhs + return number >= rhs_int elif ineq == "<=": - return number <= rhs + return number <= rhs_int else: return False @@ -83,7 +85,13 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count, sender_power_level, power_levels): + def __init__( + self, + event: EventBase, + room_member_count: int, + sender_power_level: int, + power_levels: dict, + ): self._event = event self._room_member_count = room_member_count self._sender_power_level = sender_power_level @@ -92,7 +100,7 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name): + def matches(self, condition: dict, user_id: str, display_name: str) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -106,7 +114,7 @@ class PushRuleEvaluatorForEvent(object): else: return True - def _event_match(self, condition, user_id): + def _event_match(self, condition: dict, user_id: str) -> bool: pattern = condition.get("pattern", None) if not pattern: @@ -134,7 +142,7 @@ class PushRuleEvaluatorForEvent(object): return _glob_matches(pattern, haystack) - def _contains_display_name(self, display_name): + def _contains_display_name(self, display_name: str) -> bool: if not display_name: return False @@ -142,51 +150,52 @@ class PushRuleEvaluatorForEvent(object): if not body: return False - return _glob_matches(display_name, body, word_boundary=True) + # Similar to _glob_matches, but do not treat display_name as a glob. + r = regex_cache.get((display_name, False, True), None) + if not r: + r = re.escape(display_name) + r = _re_word_boundary(r) + r = re.compile(r, flags=re.IGNORECASE) + regex_cache[(display_name, False, True)] = r + + return r.search(body) - def _get_value(self, dotted_key): + def _get_value(self, dotted_key: str) -> str: return self._value_cache.get(dotted_key, None) -# Caches (glob, word_boundary) -> regex for push. See _glob_matches +# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) register_cache("cache", "regex_push_cache", regex_cache) -def _glob_matches(glob, value, word_boundary=False): +def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: """Tests if value matches glob. Args: - glob (string) - value (string): String to test against glob. - word_boundary (bool): Whether to match against word boundaries or entire + glob + value: String to test against glob. + word_boundary: Whether to match against word boundaries or entire string. Defaults to False. - - Returns: - bool """ try: - r = regex_cache.get((glob, word_boundary), None) + r = regex_cache.get((glob, True, word_boundary), None) if not r: r = _glob_to_re(glob, word_boundary) - regex_cache[(glob, word_boundary)] = r + regex_cache[(glob, True, word_boundary)] = r return r.search(value) except re.error: logger.warning("Failed to parse glob to regex: %r", glob) return False -def _glob_to_re(glob, word_boundary): +def _glob_to_re(glob: str, word_boundary: bool) -> Pattern: """Generates regex for a given glob. Args: - glob (string) - word_boundary (bool): Whether to match against word boundaries or entire - string. Defaults to False. - - Returns: - regex object + glob + word_boundary: Whether to match against word boundaries or entire string. """ if IS_GLOB.search(glob): r = re.escape(glob) @@ -219,7 +228,7 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) -def _re_word_boundary(r): +def _re_word_boundary(r: str) -> str: """ Adds word boundary characters to the start and end of an expression to require that the match occur as a whole word, diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py new file mode 100644 index 0000000000..9ae6a87d7b --- /dev/null +++ b/tests/push/test_push_rule_evaluator.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent +from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent + +from tests import unittest + + +class PushRuleEvaluatorTestCase(unittest.TestCase): + def setUp(self): + event = FrozenEvent( + { + "event_id": "$event_id", + "type": "m.room.history_visibility", + "sender": "@user:test", + "state_key": "", + "room_id": "@room:test", + "content": {"body": "foo bar baz"}, + }, + RoomVersions.V1, + ) + room_member_count = 0 + sender_power_level = 0 + power_levels = {} + self.evaluator = PushRuleEvaluatorForEvent( + event, room_member_count, sender_power_level, power_levels + ) + + def test_display_name(self): + """Check for a matching display name in the body of the event.""" + condition = { + "kind": "contains_display_name", + } + + # Blank names are skipped. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + + # Check a display name that doesn't match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + + # Check a display name which matches. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + + # A display name that matches, but not a full word does not result in a match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + + # A display name should not be interpreted as a regular expression. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + + # A display name with spaces should work fine. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) diff --git a/tox.ini b/tox.ini index 8e3f09e638..34d6322c4b 100644 --- a/tox.ini +++ b/tox.ini @@ -194,6 +194,7 @@ commands = mypy \ synapse/metrics \ synapse/module_api \ synapse/push/pusherpool.py \ + synapse/push/push_rule_evaluator.py \ synapse/replication \ synapse/rest \ synapse/spam_checker_api \ -- cgit 1.5.1