summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py73
1 files changed, 36 insertions, 37 deletions
diff --git a/tests/unittest.py b/tests/unittest.py

index bec4a3d023..a120c2976c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: assert self.helper.auth_user_id is not None + token = "some_fake_token" # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, - "some_fake_token", + token, None, None, ) ) - async def get_user_by_access_token( - token: Optional[str] = None, allow_guest: bool = False - ) -> JsonDict: - assert self.helper.auth_user_id is not None - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": token_id, - "is_guest": False, - } - - async def get_user_by_req( - request: SynapseRequest, - allow_guest: bool = False, - allow_expired: bool = False, - ) -> Requester: + # This has to be a function and not just a Mock, because + # `self.helper.auth_user_id` is temporarily reassigned in some tests + async def get_requester(*args, **kwargs) -> Requester: assert self.helper.auth_user_id is not None return create_requester( - UserID.from_string(self.helper.auth_user_id), - token_id, - False, - False, - None, + user_id=UserID.from_string(self.helper.auth_user_id), + access_token_id=token_id, ) # Type ignore: mypy doesn't like us assigning to methods. - self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment] - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment] - return_value="1234" - ) + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] @@ -376,13 +360,13 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock): """ Make and return a homeserver. Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. + clock: The Clock, associated with the reactor. Returns: A homeserver suitable for testing. @@ -442,9 +426,8 @@ class HomeserverTestCase(TestCase): Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. - homeserver (synapse.server.HomeServer): The HomeServer to test - against. + clock: The Clock, associated with the reactor. + homeserver: The HomeServer to test against. Function to optionally be overridden in subclasses. """ @@ -468,11 +451,10 @@ class HomeserverTestCase(TestCase): given content. Args: - method (bytes/unicode): The HTTP request method ("verb"). - path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. - escaped UTF-8 & spaces and such). - content (bytes or dict): The body of the request. JSON-encoded, if - a dict. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces + and such). content (bytes or dict): The body of the request. + JSON-encoded, if a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake @@ -677,14 +659,29 @@ class HomeserverTestCase(TestCase): username: str, password: str, device_id: Optional[str] = None, + additional_request_fields: Optional[Dict[str, str]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None, ) -> str: """ Log in a user, and get an access token. Requires the Login API be registered. + + Args: + username: The localpart to assign to the new user. + password: The password to assign to the new user. + device_id: An optional device ID to assign to the new device created during + login. + additional_request_fields: A dictionary containing any additional /login + request fields and their values. + custom_headers: Custom HTTP headers and values to add to the /login request. + + Returns: + The newly registered user's Matrix ID. """ body = {"type": "m.login.password", "user": username, "password": password} if device_id: body["device_id"] = device_id + if additional_request_fields: + body.update(additional_request_fields) channel = self.make_request( "POST", @@ -735,7 +732,9 @@ class HomeserverTestCase(TestCase): event.internal_metadata.soft_failed = True self.get_success( - event_creator.handle_new_client_event(requester, event, context) + event_creator.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) return event.event_id