summary refs log tree commit diff
path: root/synapse/handlers/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/_base.py')
-rw-r--r--synapse/handlers/_base.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..bd8e71ae56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 import synapse.state
 import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.types import UserID
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,11 +34,7 @@ class BaseHandler:
     Common base class for the event handlers.
     """
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -56,7 +56,7 @@ class BaseHandler:
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
-            )
+            )  # type: Optional[Ratelimiter]
         else:
             self.admin_redaction_ratelimiter = None
 
@@ -127,15 +127,15 @@ class BaseHandler:
             if guest_access != "can_join":
                 if context:
                     current_state_ids = await context.get_current_state_ids()
-                    current_state = await self.store.get_events(
+                    current_state_dict = await self.store.get_events(
                         list(current_state_ids.values())
                     )
+                    current_state = list(current_state_dict.values())
                 else:
-                    current_state = await self.state_handler.get_current_state(
+                    current_state_map = await self.state_handler.get_current_state(
                         event.room_id
                     )
-
-                current_state = list(current_state.values())
+                    current_state = list(current_state_map.values())
 
                 logger.info("maybe_kick_guest_users %r", current_state)
                 await self.kick_guest_users(current_state)