diff --git a/changelog.d/6715.misc b/changelog.d/6715.misc
new file mode 100644
index 0000000000..8876b0446d
--- /dev/null
+++ b/changelog.d/6715.misc
@@ -0,0 +1 @@
+Add StateMap type alias to simplify types.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index abbc7079a3..2cbfab2569 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,7 +14,6 @@
# limitations under the License.
import logging
-from typing import Dict, Tuple
from six import itervalues
@@ -35,7 +34,7 @@ from synapse.api.errors import (
ResourceLimitError,
)
from synapse.config.server import is_threepid_reserved
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -509,10 +508,7 @@ class Auth(object):
return self.store.is_server_admin(user)
def compute_auth_events(
- self,
- event,
- current_state_ids: Dict[Tuple[str, str], str],
- for_verification: bool = False,
+ self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
"""Given an event and current state return the list of event IDs used
to auth an event.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a44baea365..9ea85e93e6 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Union
from six import iteritems
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import StateMap
@attr.s(slots=True)
@@ -106,13 +107,11 @@ class EventContext:
_state_group = attr.ib(default=None, type=Optional[int])
state_group_before_event = attr.ib(default=None, type=Optional[int])
prev_group = attr.ib(default=None, type=Optional[int])
- delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+ delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
app_service = attr.ib(default=None, type=Optional[ApplicationService])
- _current_state_ids = attr.ib(
- default=None, type=Optional[Dict[Tuple[str, str], str]]
- )
- _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+ _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+ _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
@staticmethod
def with_state(
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index a5b36b1827..5012aaea35 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
+from synapse.types import StateMap
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver.
@@ -77,7 +78,7 @@ class PerDestinationQueue(object):
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
- self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
+ self._pending_edus_keyed = {} # type: StateMap[Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a9407553b4..60a7c938bc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,9 +14,11 @@
# limitations under the License.
import logging
+from typing import List
from synapse.api.constants import Membership
-from synapse.types import RoomStreamToken
+from synapse.events import FrozenEvent
+from synapse.types import RoomStreamToken, StateMap
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -259,35 +261,26 @@ class ExfiltrationWriter(object):
"""Interface used to specify how to write exported data.
"""
- def write_events(self, room_id, events):
+ def write_events(self, room_id: str, events: List[FrozenEvent]):
"""Write a batch of events for a room.
-
- Args:
- room_id (str)
- events (list[FrozenEvent])
"""
pass
- def write_state(self, room_id, event_id, state):
+ def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
-
- Args:
- room_id (str)
- event_id (str)
- state (dict[tuple[str, str], FrozenEvent])
"""
pass
- def write_invite(self, room_id, event, state):
+ def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
"""Write an invite for the room, with associated invite state.
Args:
- room_id (str)
- event (FrozenEvent)
- state (dict[tuple[str, str], dict]): A subset of the state at the
+ room_id
+ event
+ state: A subset of the state at the
invite, with a subset of the event keys (type, state_key
content and sender)
"""
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 61b6713c88..d4f9a792fc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -89,7 +89,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
- auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
def shortstr(iterable, maxitems=5):
@@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(
- ours.values()
- ) # type: list[dict[tuple[str, str], str]]
+ state_maps = list(ours.values()) # type: list[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+ auth_events: Optional[StateMap[EventBase]],
backfilled: bool,
):
"""
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9cab2adbfb..9f50196ea7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ Requester,
+ RoomAlias,
+ RoomID,
+ RoomStreamToken,
+ StateMap,
+ StreamToken,
+ UserID,
+)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
@@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _update_upgraded_room_pls(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self,
+ requester: Requester,
+ old_room_id: str,
+ new_room_id: str,
+ old_room_state: StateMap[str],
):
"""Send updated power levels in both rooms after an upgrade
Args:
- requester (synapse.types.Requester): the user requesting the upgrade
- old_room_id (str): the id of the room to be replaced
- new_room_id (str): the id of the replacement room
- old_room_state (dict[tuple[str, str], str]): the state map for the old room
+ requester: the user requesting the upgrade
+ old_room_id: the id of the room to be replaced
+ new_room_id: the id of the replacement room
+ old_room_state: the state map for the old room
Returns:
Deferred
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 5accc071ab..cacd0c0c2b 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional
from six import iteritems, itervalues
@@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
room_id: str,
room_version: str,
- state_sets: List[Dict[Tuple[str, str], str]],
+ state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index b2f9865f39..d6c34ce3b7 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,7 +15,7 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional
from six import iteritems, iterkeys, itervalues
@@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks
def resolve_events_with_store(
room_id: str,
- state_sets: List[Dict[Tuple[str, str], str]],
+ state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable,
):
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 72fb8a6317..6216fdd204 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,7 @@
import heapq
import itertools
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
from six import iteritems, itervalues
@@ -27,6 +27,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
def resolve_events_with_store(
room_id: str,
room_version: str,
- state_sets: List[Dict[Tuple[str, str], str]],
+ state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
@@ -393,12 +394,12 @@ def _iterative_auth_checks(
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
- base_state (dict[tuple[str, str], str]): The set of state to start with
+ base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
- Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+ Deferred[StateMap[str]]: Returns the final updated state
"""
resolved_state = base_state.copy()
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d07440e3ed..33bebd1c48 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+ def get_filtered_current_state_ids(
+ self, room_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
- room_id (str)
- state_filter (StateFilter): The state filter used to fetch state
+ room_id
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
- event ID.
+ defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index d53695f238..c4ee9b7ccb 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
from six import iteritems
from six.moves import range
@@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
+ """Returns the state groups for a given set of groups from the
+ database, filtering on types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
results = {}
@@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
- groups (iterable[int]): list of state groups for which we want
+ groups: list of state groups for which we want
to get the state.
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+ def _get_state_for_groups_using_cache(
+ self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+ ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ groups: list of state groups for which we want to get the state.
+ cache: the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the
+ specific members state cache.
+ state_filter: The state filter used to fetch state from the
+ database.
Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
+ Tuple of dict of state_group_id to state map of entries in the
+ cache, and the state group ids either missing from the cache or
+ incomplete.
"""
results = {}
incomplete_groups = set()
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
@@ -22,9 +23,13 @@ import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
+# Used for generic functions below
+T = TypeVar("T")
+
@attr.s(slots=True)
class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types())
- def filter_state(self, state_dict):
+ def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
Args:
- state (dict[tuple[str, str], Any]): The state map to filter
+ state: The state map to filter
Returns:
- dict[tuple[str, str], Any]: The filtered state map
+ The filtered state map
"""
if self.is_full():
return dict(state_dict)
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores
- def get_state_group_delta(self, state_group):
+ def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+ Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
"""
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
event_ids (iterable[str]): ids of the events
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
+ Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
for group, event_id_map in iteritems(group_to_ids)
}
- def _get_state_groups_from_groups(self, groups, state_filter):
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
diff --git a/synapse/types.py b/synapse/types.py
index cd996c0b5a..65e4d8c181 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -17,6 +17,7 @@ import re
import string
import sys
from collections import namedtuple
+from typing import Dict, Tuple, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Sized, Iterable, Container, TypeVar
+ from typing import Sized, Iterable, Container
T_co = TypeVar("T_co", covariant=True)
@@ -36,6 +37,12 @@ else:
__slots__ = ()
+# Define a state map type from type/state_key to T (usually an event ID or
+# event)
+T = TypeVar("T")
+StateMap = Dict[Tuple[str, str], T]
+
+
class Requester(
namedtuple(
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
|