summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py13
-rw-r--r--synapse/state/v1.py5
-rw-r--r--synapse/state/v2.py7
3 files changed, 11 insertions, 14 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1b91cf5eaa..e977ed1044 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -20,7 +20,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     DefaultDict,
     Dict,
     FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
     async def compute_state_after_events(
         self,
         room_id: str,
-        event_ids: Collection[str],
+        event_ids: StrCollection,
         state_filter: Optional[StateFilter] = None,
         await_full_state: bool = True,
     ) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
         return await ret.get_state(self._state_storage_controller, state_filter)
 
     async def get_current_user_ids_in_room(
-        self, room_id: str, latest_event_ids: Collection[str]
+        self, room_id: str, latest_event_ids: StrCollection
     ) -> Set[str]:
         """
         Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
         return await self.store.get_joined_user_ids_from_state(room_id, state)
 
     async def get_hosts_in_room_at_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: StrCollection
     ) -> FrozenSet[str]:
         """Get the hosts that were in a room at the given event ids
 
@@ -470,7 +469,7 @@ class StateHandler:
     @trace
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+        self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 500e384695..c76a2f082e 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,7 +17,6 @@ import logging
 from typing import (
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Iterable,
     List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +44,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 44c49274a9..1752f95db8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -19,7 +19,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Generator,
     Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
     # This is usually synapse.state.StateResolutionStore, but it's replaced with a
     # TestStateResolutionStore in tests.
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         ...
 
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        auth_difference_unpersisted_part: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: StrCollection = union - intersection
     else:
         auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]