summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-01-16 13:31:22 +0000
committerGitHub <noreply@github.com>2020-01-16 13:31:22 +0000
commitd386f2f339c839ff6ec8d656492dd635dc26f811 (patch)
treef127b8130fb9d778d863300e97c3ac55945e2e61
parentAdd tips for the changelog to the pull request template (#6663) (diff)
downloadsynapse-d386f2f339c839ff6ec8d656492dd635dc26f811.tar.xz
Add StateMap type alias (#6715)
-rw-r--r--changelog.d/6715.misc1
-rw-r--r--synapse/api/auth.py8
-rw-r--r--synapse/events/snapshot.py11
-rw-r--r--synapse/federation/sender/per_destination_queue.py3
-rw-r--r--synapse/handlers/admin.py25
-rw-r--r--synapse/handlers/federation.py10
-rw-r--r--synapse/handlers/room.py24
-rw-r--r--synapse/state/__init__.py5
-rw-r--r--synapse/state/v1.py5
-rw-r--r--synapse/state/v2.py9
-rw-r--r--synapse/storage/data_stores/main/state.py11
-rw-r--r--synapse/storage/data_stores/state/store.py52
-rw-r--r--synapse/storage/state.py35
-rw-r--r--synapse/types.py9
14 files changed, 115 insertions, 93 deletions
diff --git a/changelog.d/6715.misc b/changelog.d/6715.misc
new file mode 100644
index 0000000000..8876b0446d
--- /dev/null
+++ b/changelog.d/6715.misc
@@ -0,0 +1 @@
+Add StateMap type alias to simplify types.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index abbc7079a3..2cbfab2569 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Tuple
 
 from six import itervalues
 
@@ -35,7 +34,7 @@ from synapse.api.errors import (
     ResourceLimitError,
 )
 from synapse.config.server import is_threepid_reserved
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
@@ -509,10 +508,7 @@ class Auth(object):
         return self.store.is_server_admin(user)
 
     def compute_auth_events(
-        self,
-        event,
-        current_state_ids: Dict[Tuple[str, str], str],
-        for_verification: bool = False,
+        self, event, current_state_ids: StateMap[str], for_verification: bool = False,
     ):
         """Given an event and current state return the list of event IDs used
         to auth an event.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a44baea365..9ea85e93e6 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Union
 
 from six import iteritems
 
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import StateMap
 
 
 @attr.s(slots=True)
@@ -106,13 +107,11 @@ class EventContext:
     _state_group = attr.ib(default=None, type=Optional[int])
     state_group_before_event = attr.ib(default=None, type=Optional[int])
     prev_group = attr.ib(default=None, type=Optional[int])
-    delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+    delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
     app_service = attr.ib(default=None, type=Optional[ApplicationService])
 
-    _current_state_ids = attr.ib(
-        default=None, type=Optional[Dict[Tuple[str, str], str]]
-    )
-    _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+    _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+    _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
 
     @staticmethod
     def with_state(
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index a5b36b1827..5012aaea35 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
+from synapse.types import StateMap
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
 # This is defined in the Matrix spec and enforced by the receiver.
@@ -77,7 +78,7 @@ class PerDestinationQueue(object):
         # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
         # based on their key (e.g. typing events by room_id)
         # Map of (edu_type, key) -> Edu
-        self._pending_edus_keyed = {}  # type: dict[tuple[str, str], Edu]
+        self._pending_edus_keyed = {}  # type: StateMap[Edu]
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a9407553b4..60a7c938bc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 
 import logging
+from typing import List
 
 from synapse.api.constants import Membership
-from synapse.types import RoomStreamToken
+from synapse.events import FrozenEvent
+from synapse.types import RoomStreamToken, StateMap
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -259,35 +261,26 @@ class ExfiltrationWriter(object):
     """Interface used to specify how to write exported data.
     """
 
-    def write_events(self, room_id, events):
+    def write_events(self, room_id: str, events: List[FrozenEvent]):
         """Write a batch of events for a room.
-
-        Args:
-            room_id (str)
-            events (list[FrozenEvent])
         """
         pass
 
-    def write_state(self, room_id, event_id, state):
+    def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
         """Write the state at the given event in the room.
 
         This only gets called for backward extremities rather than for each
         event.
-
-        Args:
-            room_id (str)
-            event_id (str)
-            state (dict[tuple[str, str], FrozenEvent])
         """
         pass
 
-    def write_invite(self, room_id, event, state):
+    def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
         """Write an invite for the room, with associated invite state.
 
         Args:
-            room_id (str)
-            event (FrozenEvent)
-            state (dict[tuple[str, str], dict]): A subset of the state at the
+            room_id
+            event
+            state: A subset of the state at the
                 invite, with a subset of the event keys (type, state_key
                 content and sender)
         """
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 61b6713c88..d4f9a792fc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import StateMap, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -89,7 +89,7 @@ class _NewEventInfo:
 
     event = attr.ib(type=EventBase)
     state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
-    auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
 
 
 def shortstr(iterable, maxitems=5):
@@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
                     ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(
-                        ours.values()
-                    )  # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())  # type: list[StateMap[str]]
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-        auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+        auth_events: Optional[StateMap[EventBase]],
         backfilled: bool,
     ):
         """
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9cab2adbfb..9f50196ea7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    Requester,
+    RoomAlias,
+    RoomID,
+    RoomStreamToken,
+    StateMap,
+    StreamToken,
+    UserID,
+)
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
@@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _update_upgraded_room_pls(
-        self, requester, old_room_id, new_room_id, old_room_state,
+        self,
+        requester: Requester,
+        old_room_id: str,
+        new_room_id: str,
+        old_room_state: StateMap[str],
     ):
         """Send updated power levels in both rooms after an upgrade
 
         Args:
-            requester (synapse.types.Requester): the user requesting the upgrade
-            old_room_id (str): the id of the room to be replaced
-            new_room_id (str): the id of the replacement room
-            old_room_state (dict[tuple[str, str], str]): the state map for the old room
+            requester: the user requesting the upgrade
+            old_room_id: the id of the room to be replaced
+            new_room_id: the id of the replacement room
+            old_room_state: the state map for the old room
 
         Returns:
             Deferred
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 5accc071ab..cacd0c0c2b 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional
 
 from six import iteritems, itervalues
 
@@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
 def resolve_events_with_store(
     room_id: str,
     room_version: str,
-    state_sets: List[Dict[Tuple[str, str], str]],
+    state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "StateResolutionStore",
 ):
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index b2f9865f39..d6c34ce3b7 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,7 +15,7 @@
 
 import hashlib
 import logging
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional
 
 from six import iteritems, iterkeys, itervalues
 
@@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 @defer.inlineCallbacks
 def resolve_events_with_store(
     room_id: str,
-    state_sets: List[Dict[Tuple[str, str], str]],
+    state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_map_factory: Callable,
 ):
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 72fb8a6317..6216fdd204 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,7 @@
 import heapq
 import itertools
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 from six import iteritems, itervalues
 
@@ -27,6 +27,7 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
 def resolve_events_with_store(
     room_id: str,
     room_version: str,
-    state_sets: List[Dict[Tuple[str, str], str]],
+    state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
 ):
@@ -393,12 +394,12 @@ def _iterative_auth_checks(
         room_id (str)
         room_version (str)
         event_ids (list[str]): Ordered list of events to apply auth checks to
-        base_state (dict[tuple[str, str], str]): The set of state to start with
+        base_state (StateMap[str]): The set of state to start with
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
 
     Returns:
-        Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+        Deferred[StateMap[str]]: Returns the final updated state
     """
     resolved_state = base_state.copy()
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d07440e3ed..33bebd1c48 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+    def get_filtered_current_state_ids(
+        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
 
         Args:
-            room_id (str)
-            state_filter (StateFilter): The state filter used to fetch state
+            room_id
+            state_filter: The state filter used to fetch state
                 from the database.
 
         Returns:
-            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
-            event ID.
+            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
         """
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index d53695f238..c4ee9b7ccb 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -15,6 +15,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
 
 from six import iteritems
 from six.moves import range
@@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, state_filter):
-        """Returns the state groups for a given set of groups, filtering on
-        types of state events.
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
+        """Returns the state groups for a given set of groups from the
+        database, filtering on types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         results = {}
 
@@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
+            groups: list of state groups for which we want
                 to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         member_filter, non_member_filter = state_filter.get_member_split()
@@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state
 
-    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+    def _get_state_for_groups_using_cache(
+        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            cache (DictionaryCache): the cache of group ids to state dicts which
-                we will pass through - either the normal state cache or the specific
-                members state cache.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            groups: list of state groups for which we want to get the state.
+            cache: the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the
+                specific members state cache.
+            state_filter: The state filter used to fetch state from the
+                database.
 
         Returns:
-            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
-            dict of state_group_id -> (dict of (type, state_key) -> event id)
-            of entries in the cache, and the state group ids either missing
-            from the cache or incomplete.
+            Tuple of dict of state_group_id to state map of entries in the
+            cache, and the state group ids either missing from the cache or
+            incomplete.
         """
         results = {}
         incomplete_groups = set()
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Iterable, List, TypeVar
 
 from six import iteritems, itervalues
 
@@ -22,9 +23,13 @@ import attr
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
+# Used for generic functions below
+T = TypeVar("T")
+
 
 @attr.s(slots=True)
 class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict):
+    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
         """Returns the state filtered with by this StateFilter
 
         Args:
-            state (dict[tuple[str, str], Any]): The state map to filter
+            state: The state map to filter
 
         Returns:
-            dict[tuple[str, str], Any]: The filtered state map
+            The filtered state map
         """
         if self.is_full():
             return dict(state_dict)
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    def get_state_group_delta(self, state_group):
+    def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                 (prev_group, delta_ids)
         """
 
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
             event_ids (iterable[str]): ids of the events
 
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
+            Deferred[dict[int, StateMap[str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
             for group, event_id_map in iteritems(group_to_ids)
         }
 
-    def _get_state_groups_from_groups(self, groups, state_filter):
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
             state_filter (StateFilter): The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)
 
diff --git a/synapse/types.py b/synapse/types.py
index cd996c0b5a..65e4d8c181 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -17,6 +17,7 @@ import re
 import string
 import sys
 from collections import namedtuple
+from typing import Dict, Tuple, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError
 if sys.version_info[:3] >= (3, 6, 0):
     from typing import Collection
 else:
-    from typing import Sized, Iterable, Container, TypeVar
+    from typing import Sized, Iterable, Container
 
     T_co = TypeVar("T_co", covariant=True)
 
@@ -36,6 +37,12 @@ else:
         __slots__ = ()
 
 
+# Define a state map type from type/state_key to T (usually an event ID or
+# event)
+T = TypeVar("T")
+StateMap = Dict[Tuple[str, str], T]
+
+
 class Requester(
     namedtuple(
         "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]