summary refs log tree commit diff
path: root/tests/handlers/test_federation_event.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_federation_event.py')
-rw-r--r--tests/handlers/test_federation_event.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e448cb1901..70ea4d15d4 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -14,6 +14,8 @@
 from typing import Optional
 from unittest import mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.errors import AuthError, StoreError
 from synapse.api.room_versions import RoomVersion
 from synapse.event_auth import (
@@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
 from synapse.logging.context import LoggingContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
 from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import event_injection, make_awaitable
@@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # mock out the federation transport client
         self.mock_federation_transport_client = mock.Mock(
             spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
@@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             )
         else:
 
-            async def get_event(destination: str, event_id: str, timeout=None):
+            async def get_event(
+                destination: str, event_id: str, timeout: Optional[int] = None
+            ) -> JsonDict:
                 self.assertEqual(destination, self.OTHER_SERVER_NAME)
                 self.assertEqual(event_id, prev_event.event_id)
                 return {"pdus": [prev_event.get_pdu_json()]}