diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 6223daf522..463ce58dae 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@ import heapq
import logging
from collections import defaultdict, namedtuple
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -52,6 +53,10 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
+
logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
@@ -74,7 +79,7 @@ _NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
-def _gen_state_id():
+def _gen_state_id() -> str:
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
@@ -109,7 +114,7 @@ class _StateCacheEntry:
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
- self.state_id = state_group
+ self.state_id: Union[str, int] = state_group
else:
self.state_id = _gen_state_id()
@@ -122,7 +127,7 @@ class StateHandler:
where necessary
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state_store = hs.get_storage().state
@@ -507,7 +512,7 @@ class StateResolutionHandler:
be storage-independent.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
@@ -636,16 +641,20 @@ class StateResolutionHandler:
"""
try:
with Measure(self.clock, "state._resolve_events") as m:
- v = KNOWN_ROOM_VERSIONS[room_version]
- if v.state_res == StateResolutionVersions.V1:
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ if room_version_obj.state_res == StateResolutionVersions.V1:
return await v1.resolve_events_with_store(
- room_id, state_sets, event_map, state_res_store.get_events
+ room_id,
+ room_version_obj,
+ state_sets,
+ event_map,
+ state_res_store.get_events,
)
else:
return await v2.resolve_events_with_store(
self.clock,
room_id,
- room_version,
+ room_version_obj,
state_sets,
event_map,
state_res_store,
@@ -653,13 +662,15 @@ class StateResolutionHandler:
finally:
self._record_state_res_metrics(room_id, m.get_resource_usage())
- def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage):
+ def _record_state_res_metrics(
+ self, room_id: str, rusage: ContextResourceUsage
+ ) -> None:
room_metrics = self._state_res_metrics[room_id]
room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime
room_metrics.db_time += rusage.db_txn_duration_sec
room_metrics.db_events += rusage.evt_db_fetch_count
- def _report_metrics(self):
+ def _report_metrics(self) -> None:
if not self._state_res_metrics:
# no state res has happened since the last iteration: don't bother logging.
return
@@ -769,16 +780,13 @@ def _make_state_cache_entry(
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class StateResolutionStore:
"""Interface that allows state resolution algorithms to access the database
in well defined way.
-
- Args:
- store (DataStore)
"""
- store = attr.ib()
+ store: "DataStore"
def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 267193cedf..92336d7cc8 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -29,7 +29,7 @@ from typing import (
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
@@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
+ room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@@ -104,7 +105,7 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
auth_events = _create_auth_events_from_maps(
- unconflicted_state, conflicted_state, state_map
+ room_version, unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.values())
@@ -132,7 +133,7 @@ async def resolve_events_with_store(
state_map.update(state_map_new)
return _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
+ room_version, unconflicted_state, conflicted_state, auth_events, state_map
)
@@ -187,6 +188,7 @@ def _seperate(
def _create_auth_events_from_maps(
+ room_version: RoomVersion,
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
@@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
"""
Args:
+ room_version: The room version.
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
@@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
- keys = event_auth.auth_types_for_event(state_map[event_id])
+ keys = event_auth.auth_types_for_event(
+ room_version, state_map[event_id]
+ )
for key in keys:
if key not in auth_events:
auth_event_id = unconflicted_state.get(key, None)
@@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
def _resolve_with_state(
+ room_version: RoomVersion,
unconflicted_state_ids: MutableStateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
@@ -235,7 +241,9 @@ def _resolve_with_state(
}
try:
- resolved_state = _resolve_state_events(conflicted_state, auth_events)
+ resolved_state = _resolve_state_events(
+ room_version, conflicted_state, auth_events
+ )
except Exception:
logger.exception("Failed to resolve state")
raise
@@ -248,7 +256,9 @@ def _resolve_with_state(
def _resolve_state_events(
- conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+ room_version: RoomVersion,
+ conflicted_state: StateMap[List[EventBase]],
+ auth_events: MutableStateMap[EventBase],
) -> StateMap[EventBase]:
"""This is where we actually decide which of the conflicted state to
use.
@@ -263,21 +273,27 @@ def _resolve_state_events(
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
+ resolved_state[POWER_KEY] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = _resolve_auth_events(events, auth_events)
+ resolved_state[key] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = _resolve_auth_events(events, auth_events)
+ resolved_state[key] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
@@ -290,12 +306,14 @@ def _resolve_state_events(
def _resolve_auth_events(
- events: List[EventBase], auth_events: StateMap[EventBase]
+ room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
- key for event in events for key in event_auth.auth_types_for_event(event)
+ key
+ for event in events
+ for key in event_auth.auth_types_for_event(room_version, event)
}
new_auth_events = {}
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e66e6571c8..7b1e8361de 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -36,7 +36,7 @@ import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
@@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
async def resolve_events_with_store(
clock: Clock,
room_id: str,
- room_version: str,
+ room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
@@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
clock: Clock,
room_id: str,
- room_version: str,
+ room_version: RoomVersion,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
@@ -519,7 +519,6 @@ async def _iterative_auth_checks(
Returns the final updated state
"""
resolved_state = dict(base_state)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for idx, event_id in enumerate(event_ids, start=1):
event = event_map[event_id]
@@ -538,7 +537,7 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None:
auth_events[(ev.type, ev.state_key)] = ev
- for key in event_auth.auth_types_for_event(event):
+ for key in event_auth.auth_types_for_event(room_version, event):
if key in resolved_state:
ev_id = resolved_state[key]
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@@ -548,7 +547,7 @@ async def _iterative_auth_checks(
try:
event_auth.check(
- room_version_obj,
+ room_version,
event,
auth_events,
do_sig_check=False,
|