From 4c96ce396e900a94af66ec070af925881b6e1e24 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 12 Nov 2021 15:50:54 +0000 Subject: Misc typing fixes for `tests`, part 1 of N (#11323) * Annotate HomeserverTestCase.servlets * Correct annotation of federation_auth_origin * Use AnyStr custom_headers instead of a Union This allows (str, str) and (bytes, bytes). This disallows (str, bytes) and (bytes, str) * DomainSpecificString.SIGIL is a ClassVar --- tests/server.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'tests/server.py') diff --git a/tests/server.py b/tests/server.py index 103351b487..a7cc5cd325 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,7 +16,16 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union +from typing import ( + AnyStr, + Callable, + Dict, + Iterable, + MutableMapping, + Optional, + Tuple, + Union, +) import attr from typing_extensions import Deque @@ -222,9 +231,7 @@ def make_request( federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ -- cgit 1.5.1 From 0dda1a79687b8375dd5b23763ba1585e5dad030d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 16 Nov 2021 10:41:35 +0000 Subject: Misc typing fixes for tests, part 2 of N (#11330) --- changelog.d/11330.misc | 1 + tests/handlers/test_register.py | 9 +++++--- tests/rest/client/utils.py | 51 +++++++++++++++++++++++++++++++++-------- tests/server.py | 3 ++- tests/unittest.py | 31 ++++++++++++------------- 5 files changed, 66 insertions(+), 29 deletions(-) create mode 100644 changelog.d/11330.misc (limited to 'tests/server.py') diff --git a/changelog.d/11330.misc b/changelog.d/11330.misc new file mode 100644 index 0000000000..86f26543dd --- /dev/null +++ b/changelog.d/11330.misc @@ -0,0 +1 @@ +Improve type annotations in Synapse's test suite. diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index db691c4c1c..cd6f2c77ae 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - self.store.count_monthly_users = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.count_monthly_users = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - self.store.get_monthly_active_count = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.get_monthly_active_count = Mock( # type: ignore[assignment] return_value=make_awaitable(self.lots_of_users) ) self.get_failure( @@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.store.get_monthly_active_count = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.get_monthly_active_count = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 7cf782e2d6..1af5e5cee5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -28,11 +28,12 @@ from typing import ( MutableMapping, Optional, Tuple, - Union, + overload, ) from unittest.mock import patch import attr +from typing_extensions import Literal from twisted.web.resource import Resource from twisted.web.server import Site @@ -55,6 +56,32 @@ class RestHelper: site = attr.ib(type=Site) auth_user_id = attr.ib() + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: Literal[200] = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> str: + ... + + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: int = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> Optional[str]: + ... + def create_room_as( self, room_creator: Optional[str] = None, @@ -64,7 +91,7 @@ class RestHelper: expect_code: int = 200, extra_content: Optional[Dict] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ) -> str: + ) -> Optional[str]: """ Create a room. @@ -107,6 +134,8 @@ class RestHelper: if expect_code == 200: return channel.json_body["room_id"] + else: + return None def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( @@ -176,7 +205,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, - expect_errcode: str = None, + expect_errcode: Optional[str] = None, ) -> None: """ Send a membership state event into a room. @@ -260,9 +289,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -509,7 +536,7 @@ class RestHelper: went. """ - cookies = {} + cookies: Dict[str, str] = {} # if we're doing a ui auth, hit the ui auth redirect endpoint if ui_auth_session_id: @@ -631,7 +658,13 @@ class RestHelper: # hit the redirect url again with the right Host header, which should now issue # a cookie and redirect to the SSO provider. - location = channel.headers.getRawHeaders("Location")[0] + def get_location(channel: FakeChannel) -> str: + location_values = channel.headers.getRawHeaders("Location") + # Keep mypy happy by asserting that location_values is nonempty + assert location_values + return location_values[0] + + location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( self.hs.get_reactor(), @@ -645,7 +678,7 @@ class RestHelper: assert channel.code == 302 channel.extract_cookies(cookies) - return channel.headers.getRawHeaders("Location")[0] + return get_location(channel) def initiate_sso_ui_auth( self, ui_auth_session_id: str, cookies: MutableMapping[str, str] diff --git a/tests/server.py b/tests/server.py index a7cc5cd325..40cf5b12c3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -24,6 +24,7 @@ from typing import ( MutableMapping, Optional, Tuple, + Type, Union, ) @@ -226,7 +227,7 @@ def make_request( path: Union[bytes, str], content: Union[bytes, str, JsonDict] = b"", access_token: Optional[str] = None, - request: Request = SynapseRequest, + request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, diff --git a/tests/unittest.py b/tests/unittest.py index ba830618c2..c9a08a3420 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from twisted.web.resource import Resource +from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership @@ -95,16 +96,13 @@ def around(target): return _around -T = TypeVar("T") - - class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the root logger's logging level while that test (case|method) runs.""" - def __init__(self, methodName, *args, **kwargs): - super().__init__(methodName, *args, **kwargs) + def __init__(self, methodName: str): + super().__init__(methodName) method = getattr(self, methodName) @@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase): Attributes: servlets: List of servlet registration function. user_id (str): The user ID to assume if auth is hijacked. - hijack_auth (bool): Whether to hijack auth to return the user specified + hijack_auth: Whether to hijack auth to return the user specified in user_id. """ - hijack_auth = True - needs_threadpool = False + hijack_auth: ClassVar[bool] = True + needs_threadpool: ClassVar[bool] = False servlets: ClassVar[List[RegisterServletsFunc]] = [] - def __init__(self, methodName, *args, **kwargs): - super().__init__(methodName, *args, **kwargs) + def __init__(self, methodName: str): + super().__init__(methodName) # see if we have any additional config for this test method = getattr(self, methodName) @@ -301,9 +299,10 @@ class HomeserverTestCase(TestCase): None, ) - self.hs.get_auth().get_user_by_req = get_user_by_req - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token - self.hs.get_auth().get_access_token_from_request = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment] + self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] + self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment] return_value="1234" ) @@ -417,7 +416,7 @@ class HomeserverTestCase(TestCase): path: Union[bytes, str], content: Union[bytes, str, JsonDict] = b"", access_token: Optional[str] = None, - request: Type[T] = SynapseRequest, + request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, @@ -596,7 +595,7 @@ class HomeserverTestCase(TestCase): nonce_str += b"\x00notadmin" want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) - want_mac = want_mac.hexdigest() + want_mac_digest = want_mac.hexdigest() body = json.dumps( { @@ -605,7 +604,7 @@ class HomeserverTestCase(TestCase): "displayname": displayname, "password": password, "admin": admin, - "mac": want_mac, + "mac": want_mac_digest, "inhibit_login": True, } ) -- cgit 1.5.1 From f7ec6e7d9e0dc360d9fb41f3a1afd7bdba1475c7 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 11:35:24 +0000 Subject: Convert one of the `setup_test_homeserver`s to `make_test_homeserver_synchronous` and pass in the homeserver rather than calling a same-named function to ask for one. Later commits will jiggle things around to make this sensible. --- tests/server.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) (limited to 'tests/server.py') diff --git a/tests/server.py b/tests/server.py index 40cf5b12c3..41eb3995bd 100644 --- a/tests/server.py +++ b/tests/server.py @@ -57,7 +57,6 @@ from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import Clock -from tests.utils import setup_test_homeserver as _sth logger = logging.getLogger(__name__) @@ -450,14 +449,11 @@ class ThreadPool: return d -def setup_test_homeserver(cleanup_func, *args, **kwargs): +def make_test_homeserver_synchronous(server: HomeServer) -> None: """ - Set up a synchronous test server, driven by the reactor used by - the homeserver. + Make the given test homeserver's database interactions synchronous. """ - server = _sth(cleanup_func, *args, **kwargs) - # Make the thread pool synchronous. clock = server.get_clock() for database in server.get_datastores().databases: @@ -485,6 +481,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction + # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) pool.running = True @@ -492,8 +489,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): # thread, so we need to disable the dedicated thread behaviour. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False - return server - def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() -- cgit 1.5.1 From b3fd99b74a3f6f42a9afd1b19ee4c60e38e8e91a Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 11:37:21 +0000 Subject: Move `tests.utils.setup_test_homeserver` to `tests.server` It had no users. We have just taken the identity of a previous function but don't provide the same behaviour, so we need to fix this in the next commit... --- tests/server.py | 185 ++++++++++++++++++++++++++++++++++++++- tests/storage/test_base.py | 3 +- tests/storage/test_roommember.py | 2 +- tests/utils.py | 175 +----------------------------------- 4 files changed, 188 insertions(+), 177 deletions(-) (limited to 'tests/server.py') diff --git a/tests/server.py b/tests/server.py index 41eb3995bd..017e5cf635 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import hashlib import json import logging +import time +import uuid +import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -27,6 +30,7 @@ from typing import ( Type, Union, ) +from unittest.mock import Mock import attr from typing_extensions import Deque @@ -53,10 +57,24 @@ from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site +from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.server import HomeServer +from synapse.storage import DataStore +from synapse.storage.engines import PostgresEngine, create_engine from synapse.types import JsonDict from synapse.util import Clock +from tests.utils import ( + LEAVE_DB, + POSTGRES_BASE_DB, + POSTGRES_HOST, + POSTGRES_PASSWORD, + POSTGRES_USER, + USE_POSTGRES_FOR_TESTS, + MockClock, + default_config, +) logger = logging.getLogger(__name__) @@ -668,3 +686,168 @@ def connect_client( client.makeConnection(FakeTransport(server, reactor)) return client, server + + +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + +def setup_test_homeserver( + cleanup_func, + name="test", + config=None, + reactor=None, + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs, +): + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. + """ + if reactor is None: + from twisted.internet import reactor + + if config is None: + config = default_config(name, parse=True) + + config.ldap_enabled = False + + if "clock" not in kwargs: + kwargs["clock"] = MockClock() + + if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + + database_config = { + "name": "psycopg2", + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + else: + database_config = { + "name": "sqlite3", + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + } + + if "db_txn_limit" in kwargs: + database_config["txn_limit"] = kwargs["db_txn_limit"] + + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) + + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + + hs = homeserver_to_use( + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, + ) + + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + + # bcrypt is far too slow to be doing in unit tests + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash + + return hs diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index ddad44bd6c..3e4f0579c9 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest -from tests.utils import TestHomeServer, default_config +from tests.server import TestHomeServer +from tests.utils import default_config class SQLBaseStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index fccab733c0..5cfdfe9b85 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -19,8 +19,8 @@ from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest +from tests.server import TestHomeServer from tests.test_utils import event_injection -from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 983859120f..6d013e8518 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,12 +14,7 @@ # limitations under the License. import atexit -import hashlib import os -import time -import uuid -import warnings -from typing import Type from unittest.mock import Mock, patch from urllib import parse as urlparse @@ -28,14 +23,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions -from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context -from synapse.server import HomeServer -from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import PostgresEngine, create_engine +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -182,171 +174,6 @@ def default_config(name, parse=False): return config_dict -class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore - - -def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, - homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): - """ - Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. - - If no datastore is supplied, one is created and given to the homeserver. - - Args: - cleanup_func : The function used to register a cleanup routine for - after the test. - - Calling this method directly is deprecated: you should instead derive from - HomeserverTestCase. - """ - if reactor is None: - from twisted.internet import reactor - - if config is None: - config = default_config(name, parse=True) - - config.ldap_enabled = False - - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - - if USE_POSTGRES_FOR_TESTS: - test_db = "synapse_test_%s" % uuid.uuid4().hex - - database_config = { - "name": "psycopg2", - "args": { - "database": test_db, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - else: - database_config = { - "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, - } - - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] - - database = DatabaseConnectionConfig("master", database_config) - config.database.databases = [database] - - db_engine = create_engine(database.config) - - # Create the database before we actually try and connect to it, based off - # the template database we generate in setupdb() - if isinstance(db_engine, PostgresEngine): - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - cur.execute( - "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) - ) - cur.close() - db_conn.close() - - hs = homeserver_to_use( - name, - config=config, - version_string="Synapse/tests", - reactor=reactor, - ) - - # Install @cache_in_self attributes - for key, val in kwargs.items(): - setattr(hs, "_" + key, val) - - # Mock TLS - hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() - - hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) - - # bcrypt is far too slow to be doing in unit tests - # Need to let the HS build an auth handler and then mess with it - # because AuthHandler's constructor requires the HS, so we can't make one - # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): - return hashlib.md5(p.encode("utf8")).hexdigest() - - hs.get_auth_handler().hash = hash - - async def validate_hash(p, h): - return hashlib.md5(p.encode("utf8")).hexdigest() == h - - hs.get_auth_handler().validate_hash = validate_hash - - return hs - - def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} -- cgit 1.5.1 From 7be88fbf48156b36b6daefb228e1258e7d48cae4 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 11:40:05 +0000 Subject: Give `tests.server.setup_test_homeserver` (nominally!) the same behaviour by calling into `make_test_homeserver_synchronous`. The function *could* have been inlined at this point but the function is big enough and it felt fine to leave it as is. At least there isn't a confusing name clash anymore! --- tests/server.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests/server.py') diff --git a/tests/server.py b/tests/server.py index 017e5cf635..b29df37595 100644 --- a/tests/server.py +++ b/tests/server.py @@ -850,4 +850,7 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash + # Make the threadpool and database transactions synchronous for testing. + make_test_homeserver_synchronous(hs) + return hs -- cgit 1.5.1 From 8cd68b8102eeab1b525712097c1b2e9679c11896 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 12:31:28 +0000 Subject: Revert accidental commits to develop. --- changelog.d/11503.misc | 1 - tests/server.py | 199 ++------------------------------------- tests/storage/test_base.py | 3 +- tests/storage/test_roommember.py | 2 +- tests/utils.py | 175 +++++++++++++++++++++++++++++++++- 5 files changed, 185 insertions(+), 195 deletions(-) delete mode 100644 changelog.d/11503.misc (limited to 'tests/server.py') diff --git a/changelog.d/11503.misc b/changelog.d/11503.misc deleted file mode 100644 index 03a24a9224..0000000000 --- a/changelog.d/11503.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. \ No newline at end of file diff --git a/tests/server.py b/tests/server.py index b29df37595..40cf5b12c3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hashlib + import json import logging -import time -import uuid -import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -30,7 +27,6 @@ from typing import ( Type, Union, ) -from unittest.mock import Mock import attr from typing_extensions import Deque @@ -57,24 +53,11 @@ from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site -from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest -from synapse.server import HomeServer -from synapse.storage import DataStore -from synapse.storage.engines import PostgresEngine, create_engine from synapse.types import JsonDict from synapse.util import Clock -from tests.utils import ( - LEAVE_DB, - POSTGRES_BASE_DB, - POSTGRES_HOST, - POSTGRES_PASSWORD, - POSTGRES_USER, - USE_POSTGRES_FOR_TESTS, - MockClock, - default_config, -) +from tests.utils import setup_test_homeserver as _sth logger = logging.getLogger(__name__) @@ -467,11 +450,14 @@ class ThreadPool: return d -def make_test_homeserver_synchronous(server: HomeServer) -> None: +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ - Make the given test homeserver's database interactions synchronous. + Set up a synchronous test server, driven by the reactor used by + the homeserver. """ + server = _sth(cleanup_func, *args, **kwargs) + # Make the thread pool synchronous. clock = server.get_clock() for database in server.get_datastores().databases: @@ -499,7 +485,6 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction - # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) pool.running = True @@ -507,6 +492,8 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: # thread, so we need to disable the dedicated thread behaviour. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server + def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() @@ -686,171 +673,3 @@ def connect_client( client.makeConnection(FakeTransport(server, reactor)) return client, server - - -class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore - - -def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, - homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): - """ - Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. - - If no datastore is supplied, one is created and given to the homeserver. - - Args: - cleanup_func : The function used to register a cleanup routine for - after the test. - - Calling this method directly is deprecated: you should instead derive from - HomeserverTestCase. - """ - if reactor is None: - from twisted.internet import reactor - - if config is None: - config = default_config(name, parse=True) - - config.ldap_enabled = False - - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - - if USE_POSTGRES_FOR_TESTS: - test_db = "synapse_test_%s" % uuid.uuid4().hex - - database_config = { - "name": "psycopg2", - "args": { - "database": test_db, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - else: - database_config = { - "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, - } - - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] - - database = DatabaseConnectionConfig("master", database_config) - config.database.databases = [database] - - db_engine = create_engine(database.config) - - # Create the database before we actually try and connect to it, based off - # the template database we generate in setupdb() - if isinstance(db_engine, PostgresEngine): - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - cur.execute( - "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) - ) - cur.close() - db_conn.close() - - hs = homeserver_to_use( - name, - config=config, - version_string="Synapse/tests", - reactor=reactor, - ) - - # Install @cache_in_self attributes - for key, val in kwargs.items(): - setattr(hs, "_" + key, val) - - # Mock TLS - hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() - - hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) - - # bcrypt is far too slow to be doing in unit tests - # Need to let the HS build an auth handler and then mess with it - # because AuthHandler's constructor requires the HS, so we can't make one - # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): - return hashlib.md5(p.encode("utf8")).hexdigest() - - hs.get_auth_handler().hash = hash - - async def validate_hash(p, h): - return hashlib.md5(p.encode("utf8")).hexdigest() == h - - hs.get_auth_handler().validate_hash = validate_hash - - # Make the threadpool and database transactions synchronous for testing. - make_test_homeserver_synchronous(hs) - - return hs diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c9..ddad44bd6c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -23,8 +23,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest -from tests.server import TestHomeServer -from tests.utils import default_config +from tests.utils import TestHomeServer, default_config class SQLBaseStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b85..fccab733c0 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -19,8 +19,8 @@ from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest -from tests.server import TestHomeServer from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 6d013e8518..983859120f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,12 @@ # limitations under the License. import atexit +import hashlib import os +import time +import uuid +import warnings +from typing import Type from unittest.mock import Mock, patch from urllib import parse as urlparse @@ -23,11 +28,14 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context +from synapse.server import HomeServer +from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -174,6 +182,171 @@ def default_config(name, parse=False): return config_dict +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + +def setup_test_homeserver( + cleanup_func, + name="test", + config=None, + reactor=None, + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs, +): + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. + """ + if reactor is None: + from twisted.internet import reactor + + if config is None: + config = default_config(name, parse=True) + + config.ldap_enabled = False + + if "clock" not in kwargs: + kwargs["clock"] = MockClock() + + if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + + database_config = { + "name": "psycopg2", + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + else: + database_config = { + "name": "sqlite3", + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + } + + if "db_txn_limit" in kwargs: + database_config["txn_limit"] = kwargs["db_txn_limit"] + + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) + + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + + hs = homeserver_to_use( + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, + ) + + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + + # bcrypt is far too slow to be doing in unit tests + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash + + return hs + + def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} -- cgit 1.5.1