summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8183.misc1
-rw-r--r--synapse/handlers/federation.py20
-rw-r--r--synapse/handlers/room.py3
-rw-r--r--synapse/handlers/sync.py5
-rw-r--r--synapse/state/__init__.py32
-rw-r--r--synapse/state/v1.py10
-rw-r--r--synapse/state/v2.py6
-rw-r--r--synapse/types.py7
8 files changed, 52 insertions, 32 deletions
diff --git a/changelog.d/8183.misc b/changelog.d/8183.misc
new file mode 100644
index 0000000000..78d8834328
--- /dev/null
+++ b/changelog.d/8183.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.state`.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f8b234cee2..155d087413 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    MutableStateMap,
+    StateMap,
+    UserID,
+    get_domain_from_id,
+)
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -96,7 +102,7 @@ class _NewEventInfo:
 
     event = attr.ib(type=EventBase)
     state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
-    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
 
 
 class FederationHandler(BaseHandler):
@@ -2053,7 +2059,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-        auth_events: Optional[StateMap[EventBase]],
+        auth_events: Optional[MutableStateMap[EventBase]],
         backfilled: bool,
     ) -> EventContext:
         context = await self.state_handler.compute_event_context(event, old_state=state)
@@ -2137,7 +2143,9 @@ class FederationHandler(BaseHandler):
             current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
             )
-            current_state_ids = {k: e.event_id for k, e in current_states.items()}
+            current_state_ids = {
+                k: e.event_id for k, e in current_states.items()
+            }  # type: StateMap[str]
         else:
             current_state_ids = await self.state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -2223,7 +2231,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: StateMap[EventBase],
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """
 
@@ -2274,7 +2282,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: StateMap[EventBase],
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """Helper for do_auth. See there for docs.
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 236a37f777..1419d72e94 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -41,6 +41,7 @@ from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
+    MutableStateMap,
     Requester,
     RoomAlias,
     RoomID,
@@ -814,7 +815,7 @@ class RoomCreationHandler(BaseHandler):
         room_id: str,
         preset_config: str,
         invite_list: List[str],
-        initial_state: StateMap,
+        initial_state: MutableStateMap,
         creation_content: JsonDict,
         room_alias: Optional[RoomAlias] = None,
         power_level_content_override: Optional[JsonDict] = None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c42dac18f5..9a86eb01c9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import (
     Collection,
     JsonDict,
+    MutableStateMap,
     RoomStreamToken,
     StateMap,
     StreamToken,
@@ -588,7 +589,7 @@ class SyncHandler(object):
         room_id: str,
         sync_config: SyncConfig,
         batch: TimelineBatch,
-        state: StateMap[EventBase],
+        state: MutableStateMap[EventBase],
         now_token: StreamToken,
     ) -> Optional[JsonDict]:
         """ Works out a room summary block for this room, summarising the number
@@ -736,7 +737,7 @@ class SyncHandler(object):
         since_token: Optional[StreamToken],
         now_token: StreamToken,
         full_state: bool,
-    ) -> StateMap[EventBase]:
+    ) -> MutableStateMap[EventBase]:
         """ Works out the difference in state between the start of the timeline
         and the previous sync.
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a601303fa3..9bf2ec368f 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -25,6 +25,7 @@ from typing import (
     Sequence,
     Set,
     Union,
+    cast,
     overload,
 )
 
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -205,7 +206,7 @@ class StateHandler(object):
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return dict(ret.state)
+        return ret.state
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -302,7 +303,7 @@ class StateHandler(object):
             # if we're given the state before the event, then we use that
             state_ids_before_event = {
                 (s.type, s.state_key): s.event_id for s in old_state
-            }
+            }  # type: StateMap[str]
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
@@ -315,7 +316,7 @@ class StateHandler(object):
                 event.room_id, event.prev_event_ids()
             )
 
-            state_ids_before_event = dict(entry.state)
+            state_ids_before_event = entry.state
             state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
@@ -540,7 +541,7 @@ class StateResolutionHandler(object):
             #
             # XXX: is this actually worthwhile, or should we just let
             # resolve_events_with_store do it?
-            new_state = {}
+            new_state = {}  # type: MutableStateMap[str]
             conflicted_state = False
             for st in state_groups_ids.values():
                 for key, e_id in st.items():
@@ -554,13 +555,20 @@ class StateResolutionHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = await resolve_events_with_store(
-                        self.clock,
-                        room_id,
-                        room_version,
-                        list(state_groups_ids.values()),
-                        event_map=event_map,
-                        state_res_store=state_res_store,
+                    # resolve_events_with_store returns a StateMap, but we can
+                    # treat it as a MutableStateMap as it is above. It isn't
+                    # actually mutated anymore (and is frozen in
+                    # _make_state_cache_entry below).
+                    new_state = cast(
+                        MutableStateMap,
+                        await resolve_events_with_store(
+                            self.clock,
+                            room_id,
+                            room_version,
+                            list(state_groups_ids.values()),
+                            event_map=event_map,
+                            state_res_store=state_res_store,
+                        ),
                     )
 
             # if the new state matches any of the input state groups, we can
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 0eb7fdd9e5..a493279cbd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -32,7 +32,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +131,7 @@ async def resolve_events_with_store(
 
 def _seperate(
     state_sets: Iterable[StateMap[str]],
-) -> Tuple[StateMap[str], StateMap[Set[str]]]:
+) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
@@ -152,7 +152,7 @@ def _seperate(
     """
     state_set_iterator = iter(state_sets)
     unconflicted_state = dict(next(state_set_iterator))
-    conflicted_state = {}  # type: StateMap[Set[str]]
+    conflicted_state = {}  # type: MutableStateMap[Set[str]]
 
     for state_set in state_set_iterator:
         for key, value in state_set.items():
@@ -208,7 +208,7 @@ def _create_auth_events_from_maps(
 
 
 def _resolve_with_state(
-    unconflicted_state_ids: StateMap[str],
+    unconflicted_state_ids: MutableStateMap[str],
     conflicted_state_ids: StateMap[Set[str]],
     auth_event_ids: StateMap[str],
     state_map: Dict[str, EventBase],
@@ -241,7 +241,7 @@ def _resolve_with_state(
 
 
 def _resolve_state_events(
-    conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
+    conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
 ) -> StateMap[EventBase]:
     """ This is where we actually decide which of the conflicted state to
     use.
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 0e9ffbd6e6..edf94e7ad6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -414,7 +414,7 @@ async def _iterative_auth_checks(
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
     state_res_store: "synapse.state.StateResolutionStore",
-) -> StateMap[str]:
+) -> MutableStateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
@@ -430,7 +430,7 @@ async def _iterative_auth_checks(
     Returns:
         Returns the final updated state
     """
-    resolved_state = base_state.copy()
+    resolved_state = dict(base_state)
     room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
     for idx, event_id in enumerate(event_ids, start=1):
diff --git a/synapse/types.py b/synapse/types.py
index bc36cdde30..f8b9b03850 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,7 @@ import re
 import string
 import sys
 from collections import namedtuple
-from typing import Any, Dict, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -41,8 +41,9 @@ else:
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
 T = TypeVar("T")
-StateMap = Dict[Tuple[str, str], T]
-
+StateKey = Tuple[str, str]
+StateMap = Mapping[StateKey, T]
+MutableStateMap = MutableMapping[StateKey, T]
 
 # the type of a JSON-serialisable dict. This could be made stronger, but it will
 # do for now.