diff options
-rw-r--r-- | changelog.d/8183.misc | 1 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 20 | ||||
-rw-r--r-- | synapse/handlers/room.py | 3 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 5 | ||||
-rw-r--r-- | synapse/state/__init__.py | 32 | ||||
-rw-r--r-- | synapse/state/v1.py | 10 | ||||
-rw-r--r-- | synapse/state/v2.py | 6 | ||||
-rw-r--r-- | synapse/types.py | 7 |
8 files changed, 52 insertions, 32 deletions
diff --git a/changelog.d/8183.misc b/changelog.d/8183.misc new file mode 100644 index 0000000000..78d8834328 --- /dev/null +++ b/changelog.d/8183.misc @@ -0,0 +1 @@ +Add type hints to `synapse.state`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f8b234cee2..155d087413 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,7 +72,13 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -96,7 +102,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) + auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) class FederationHandler(BaseHandler): @@ -2053,7 +2059,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[StateMap[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], backfilled: bool, ) -> EventContext: context = await self.state_handler.compute_event_context(event, old_state=state) @@ -2137,7 +2143,9 @@ class FederationHandler(BaseHandler): current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_states.items()} + current_state_ids = { + k: e.event_id for k, e in current_states.items() + } # type: StateMap[str] else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2223,7 +2231,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """ @@ -2274,7 +2282,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """Helper for do_auth. See there for docs. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 236a37f777..1419d72e94 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -41,6 +41,7 @@ from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, + MutableStateMap, Requester, RoomAlias, RoomID, @@ -814,7 +815,7 @@ class RoomCreationHandler(BaseHandler): room_id: str, preset_config: str, invite_list: List[str], - initial_state: StateMap, + initial_state: MutableStateMap, creation_content: JsonDict, room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c42dac18f5..9a86eb01c9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter from synapse.types import ( Collection, JsonDict, + MutableStateMap, RoomStreamToken, StateMap, StreamToken, @@ -588,7 +589,7 @@ class SyncHandler(object): room_id: str, sync_config: SyncConfig, batch: TimelineBatch, - state: StateMap[EventBase], + state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number @@ -736,7 +737,7 @@ class SyncHandler(object): since_token: Optional[StreamToken], now_token: StreamToken, full_state: bool, - ) -> StateMap[EventBase]: + ) -> MutableStateMap[EventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a601303fa3..9bf2ec368f 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -25,6 +25,7 @@ from typing import ( Sequence, Set, Union, + cast, overload, ) @@ -41,7 +42,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -205,7 +206,7 @@ class StateHandler(object): logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return dict(ret.state) + return ret.state async def get_current_users_in_room( self, room_id: str, latest_event_ids: Optional[List[str]] = None @@ -302,7 +303,7 @@ class StateHandler(object): # if we're given the state before the event, then we use that state_ids_before_event = { (s.type, s.state_key): s.event_id for s in old_state - } + } # type: StateMap[str] state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -315,7 +316,7 @@ class StateHandler(object): event.room_id, event.prev_event_ids() ) - state_ids_before_event = dict(entry.state) + state_ids_before_event = entry.state state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids @@ -540,7 +541,7 @@ class StateResolutionHandler(object): # # XXX: is this actually worthwhile, or should we just let # resolve_events_with_store do it? - new_state = {} + new_state = {} # type: MutableStateMap[str] conflicted_state = False for st in state_groups_ids.values(): for key, e_id in st.items(): @@ -554,13 +555,20 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, + # resolve_events_with_store returns a StateMap, but we can + # treat it as a MutableStateMap as it is above. It isn't + # actually mutated anymore (and is frozen in + # _make_state_cache_entry below). + new_state = cast( + MutableStateMap, + await resolve_events_with_store( + self.clock, + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ), ) # if the new state matches any of the input state groups, we can diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 0eb7fdd9e5..a493279cbd 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -32,7 +32,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -131,7 +131,7 @@ async def resolve_events_with_store( def _seperate( state_sets: Iterable[StateMap[str]], -) -> Tuple[StateMap[str], StateMap[Set[str]]]: +) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. @@ -152,7 +152,7 @@ def _seperate( """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} # type: StateMap[Set[str]] + conflicted_state = {} # type: MutableStateMap[Set[str]] for state_set in state_set_iterator: for key, value in state_set.items(): @@ -208,7 +208,7 @@ def _create_auth_events_from_maps( def _resolve_with_state( - unconflicted_state_ids: StateMap[str], + unconflicted_state_ids: MutableStateMap[str], conflicted_state_ids: StateMap[Set[str]], auth_event_ids: StateMap[str], state_map: Dict[str, EventBase], @@ -241,7 +241,7 @@ def _resolve_with_state( def _resolve_state_events( - conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] + conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] ) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 0e9ffbd6e6..edf94e7ad6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -414,7 +414,7 @@ async def _iterative_auth_checks( base_state: StateMap[str], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", -) -> StateMap[str]: +) -> MutableStateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -430,7 +430,7 @@ async def _iterative_auth_checks( Returns: Returns the final updated state """ - resolved_state = base_state.copy() + resolved_state = dict(base_state) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for idx, event_id in enumerate(event_ids, start=1): diff --git a/synapse/types.py b/synapse/types.py index bc36cdde30..f8b9b03850 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,7 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -41,8 +41,9 @@ else: # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") -StateMap = Dict[Tuple[str, str], T] - +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] # the type of a JSON-serialisable dict. This could be made stronger, but it will # do for now. |