summary refs log tree commit diff
path: root/tests/replication/test_account_validity.py
diff options
context:
space:
mode:
authorMathieu Velten <matmaul@gmail.com>2022-07-21 09:59:27 +0200
committerMathieu Velten <mathieuv@matrix.org>2022-08-12 16:20:12 +0200
commit13113d83781062cb5f0564aef1e6049464150455 (patch)
tree48b0b1e0c33299293fca573d7d2a05cbcbd85466 /tests/replication/test_account_validity.py
parentAdd some miscellaneous comments around sync (#13474) (diff)
downloadsynapse-13113d83781062cb5f0564aef1e6049464150455.tar.xz
Add test for account validity feature
Diffstat (limited to 'tests/replication/test_account_validity.py')
-rw-r--r--tests/replication/test_account_validity.py238
1 files changed, 238 insertions, 0 deletions
diff --git a/tests/replication/test_account_validity.py b/tests/replication/test_account_validity.py
new file mode 100644

index 0000000000..408eb266b9 --- /dev/null +++ b/tests/replication/test_account_validity.py
@@ -0,0 +1,238 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Optional, cast + +from twisted.internet.defer import ensureDeferred + +import synapse +from synapse.module_api import DatabasePool, LoggingTransaction, ModuleApi, cached +from synapse.server import HomeServer + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import ThreadedMemoryReactorClock, make_request + +logger = logging.getLogger(__name__) + + +class MockAccountValidityStore: + def __init__( + self, + api: ModuleApi, + ): + self._api = api + + async def create_db(self): + def create_table_txn(txn: LoggingTransaction): + txn.execute( + """ + CREATE TABLE IF NOT EXISTS mock_account_validity( + user_id TEXT PRIMARY KEY, + expired BOOLEAN NOT NULL + ) + """, + (), + ) + + await self._api.run_db_interaction( + "account_validity_create_table", + create_table_txn, + ) + + @cached() + async def is_user_expired(self, user_id: str) -> Optional[bool]: + def get_expiration_for_user_txn(txn: LoggingTransaction): + return DatabasePool.simple_select_one_onecol_txn( + txn=txn, + table="mock_account_validity", + keyvalues={"user_id": user_id}, + retcol="expired", + allow_none=True, + ) + + return await self._api.run_db_interaction( + "get_expiration_for_user", + get_expiration_for_user_txn, + ) + + async def on_user_registration(self, user_id: str) -> None: + def add_valid_user_txn(txn: LoggingTransaction): + txn.execute( + "INSERT INTO mock_account_validity (user_id, expired) VALUES (?, ?)", + (user_id, False), + ) + + await self._api.run_db_interaction( + "account_validity_add_valid_user", + add_valid_user_txn, + ) + + async def set_expired(self, user_id: str, expired: bool = True) -> None: + def set_expired_user_txn(txn: LoggingTransaction): + txn.execute( + "UPDATE mock_account_validity SET expired = ? WHERE user_id = ?", + ( + expired, + user_id, + ), + ) + + txn.call_after(self.is_user_expired.invalidate, (user_id,)) + + await self._api.run_db_interaction( + "account_validity_set_expired_user", + set_expired_user_txn, + ) + + +class MockAccountValidity: + def __init__( + self, + config, + api: ModuleApi, + ): + self._api = api + + self._store = MockAccountValidityStore(api) + + ensureDeferred(self._store.create_db()) + cast(ThreadedMemoryReactorClock, api._hs.get_reactor()).pump([0.0]) + + self._api.register_account_validity_callbacks( + is_user_expired=self.is_user_expired, + on_user_registration=self.on_user_registration, + ) + + async def is_user_expired(self, user_id: str) -> Optional[bool]: + return await self._store.is_user_expired(user_id) + + async def on_user_registration(self, user_id: str) -> None: + await self._store.on_user_registration(user_id) + + +class WorkerAccountValidityTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.account.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.register.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + config["modules"] = [ + { + "module": __name__ + ".MockAccountValidity", + } + ] + + return config + + def make_homeserver(self, reactor, clock): + hs = super().make_homeserver(reactor, clock) + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + self.module = module(config=config, api=module_api) + logger.info("Loaded module %s", self.module) + return hs + + def make_worker_hs( + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs + ) -> HomeServer: + hs = super().make_worker_hs(worker_app, extra_config=extra_config) + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + # Do not store the module in self here since we want to expire the user + # from the main worker and see if it get properly replicated to the other one. + module(config=config, api=module_api) + logger.info("Loaded module %s", self.module) + return hs + + def _create_and_check_user(self): + self.register_user("user", "pass") + user_id = "@user:test" + token = self.login("user", "pass") + + channel = self.make_request( + "GET", + "/_matrix/client/v3/account/whoami", + access_token=token, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["user_id"], user_id) + + return user_id, token + + def test_account_validity(self): + user_id, token = self._create_and_check_user() + + self.get_success_or_raise(self.module._store.set_expired(user_id)) + + channel = self.make_request( + "GET", + "/_matrix/client/v3/account/whoami", + access_token=token, + ) + self.assertEqual(channel.code, 403) + + self.get_success_or_raise(self.module._store.set_expired(user_id, False)) + + channel = self.make_request( + "GET", + "/_matrix/client/v3/account/whoami", + access_token=token, + ) + self.assertEqual(channel.code, 200) + + def test_account_validity_with_worker_and_cache(self): + worker_hs = self.make_worker_hs("synapse.app.generic_worker") + worker_site = self._hs_to_site[worker_hs] + + user_id, token = self._create_and_check_user() + + # check than the user is valid on the other worker too + channel = make_request( + self.reactor, + worker_site, + "GET", + "/_matrix/client/v3/account/whoami", + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # Expires user on the main worker, and check its state on the other worker + self.get_success_or_raise(self.module._store.set_expired(user_id)) + + channel = make_request( + self.reactor, + worker_site, + "GET", + "/_matrix/client/v3/account/whoami", + access_token=token, + ) + self.assertEqual(channel.code, 403) + + # Un-expires user on the main worker, and check its state on the other worker + self.get_success_or_raise(self.module._store.set_expired(user_id, False)) + + channel = make_request( + self.reactor, + worker_site, + "GET", + "/_matrix/client/v3/account/whoami", + access_token=token, + ) + self.assertEqual(channel.code, 200)