diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1b91cf5eaa..e977ed1044 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
DefaultDict,
Dict,
FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
async def compute_state_after_events(
self,
room_id: str,
- event_ids: Collection[str],
+ event_ids: StrCollection,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
- self, room_id: str, latest_event_ids: Collection[str]
+ self, room_id: str, latest_event_ids: StrCollection
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
- self, room_id: str, event_ids: Collection[str]
+ self, room_id: str, event_ids: StrCollection
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
@@ -470,7 +469,7 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
- self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+ self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
- self, event_ids: Collection[str], allow_rejected: bool = False
+ self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
|