summary refs log tree commit diff
path: root/tests/handlers/test_directory.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_directory.py')
-rw-r--r--tests/handlers/test_directory.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 3b72c4c9d0..367d94eca3 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -13,27 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.api.errors
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
 from synapse.rest.client import directory, login, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomAlias, create_requester
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the directory service."""
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.mock_federation = Mock()
+        self.mock_federation = AsyncMock()
         self.mock_registry = Mock()
 
         self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -72,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self) -> None:
-        self.mock_federation.make_query.return_value = make_awaitable(
-            {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
-        )
+        self.mock_federation.make_query.return_value = {
+            "room_id": "!8765qwer:test",
+            "servers": ["test", "remote"],
+        }
 
         result = self.get_success(self.handler.get_association(self.remote_room))
 
@@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
-    def _create_alias(self, user) -> None:
+    def _create_alias(self, user: str) -> None:
         # Create a new alias to this room.
         self.get_success(
             self.store.create_room_alias_association(
@@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
         return room_alias
 
-    def _set_canonical_alias(self, content) -> None:
+    def _set_canonical_alias(self, content: JsonDict) -> None:
         """Configure the canonical alias state on the room."""
         self.helper.send_state(
             self.room_id,
@@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
             tok=self.admin_user_tok,
         )
 
-    def _get_canonical_alias(self):
+    def _get_canonical_alias(self) -> EventBase:
         """Get the canonical alias state of the room."""
-        return self.get_success(
+        result = self.get_success(
             self._storage_controllers.state.get_current_state_event(
                 self.room_id, EventTypes.CanonicalAlias, ""
             )
         )
+        assert result is not None
+        return result
 
     def test_remove_alias(self) -> None:
         """Removing an alias that is the canonical alias should remove it there too."""
@@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
 
         data = self._get_canonical_alias()
-        self.assertEqual(data["content"]["alias"], self.test_alias)
-        self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+        self.assertEqual(data.content["alias"], self.test_alias)
+        self.assertEqual(data.content["alt_aliases"], [self.test_alias])
 
         # Finally, delete the alias.
         self.get_success(
@@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
 
         data = self._get_canonical_alias()
-        self.assertNotIn("alias", data["content"])
-        self.assertNotIn("alt_aliases", data["content"])
+        self.assertNotIn("alias", data.content)
+        self.assertNotIn("alt_aliases", data.content)
 
     def test_remove_other_alias(self) -> None:
         """Removing an alias listed as in alt_aliases should remove it there too."""
@@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
 
         data = self._get_canonical_alias()
-        self.assertEqual(data["content"]["alias"], self.test_alias)
+        self.assertEqual(data.content["alias"], self.test_alias)
         self.assertEqual(
-            data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
+            data.content["alt_aliases"], [self.test_alias, other_test_alias]
         )
 
         # Delete the second alias.
@@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
 
         data = self._get_canonical_alias()
-        self.assertEqual(data["content"]["alias"], self.test_alias)
-        self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+        self.assertEqual(data.content["alias"], self.test_alias)
+        self.assertEqual(data.content["alt_aliases"], [self.test_alias])
 
 
 class TestCreateAliasACL(unittest.HomeserverTestCase):