diff options
Diffstat (limited to 'tests/handlers/test_directory.py')
-rw-r--r-- | tests/handlers/test_directory.py | 39 |
1 files changed, 21 insertions, 18 deletions
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 3b72c4c9d0..367d94eca3 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -13,27 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Awaitable, Callable, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.rest.client import directory, login, room from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, create_requester from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = Mock() + self.mock_federation = AsyncMock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -72,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self) -> None: - self.mock_federation.make_query.return_value = make_awaitable( - {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} - ) + self.mock_federation.make_query.return_value = { + "room_id": "!8765qwer:test", + "servers": ["test", "remote"], + } result = self.get_success(self.handler.get_association(self.remote_room)) @@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user) -> None: + def _create_alias(self, user: str) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) return room_alias - def _set_canonical_alias(self, content) -> None: + def _set_canonical_alias(self, content: JsonDict) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def _get_canonical_alias(self): + def _get_canonical_alias(self) -> EventBase: """Get the canonical alias state of the room.""" - return self.get_success( + result = self.get_success( self._storage_controllers.state.get_current_state_event( self.room_id, EventTypes.CanonicalAlias, "" ) ) + assert result is not None + return result def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" @@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) - self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + self.assertEqual(data.content["alias"], self.test_alias) + self.assertEqual(data.content["alt_aliases"], [self.test_alias]) # Finally, delete the alias. self.get_success( @@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertNotIn("alias", data["content"]) - self.assertNotIn("alt_aliases", data["content"]) + self.assertNotIn("alias", data.content) + self.assertNotIn("alt_aliases", data.content) def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" @@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data.content["alias"], self.test_alias) self.assertEqual( - data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + data.content["alt_aliases"], [self.test_alias, other_test_alias] ) # Delete the second alias. @@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) - self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + self.assertEqual(data.content["alias"], self.test_alias) + self.assertEqual(data.content["alt_aliases"], [self.test_alias]) class TestCreateAliasACL(unittest.HomeserverTestCase): |