diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 8c091d07c9..222daa0b28 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -23,16 +23,15 @@ from frozendict import frozendict
from twisted.internet import defer
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RoomVersions
from synapse.events.snapshot import EventContext
+from synapse.state import v1
from synapse.util.async import Linearizer
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
-from .v1 import resolve_events_with_factory, resolve_events_with_state_map
-
logger = logging.getLogger(__name__)
@@ -263,8 +262,14 @@ class StateHandler(object):
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
+ if event.type == EventTypes.Create:
+ room_version = event.content.get("room_version", RoomVersions.V1)
+ else:
+ room_version = None
+
entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events],
+ explicit_room_version=room_version,
)
prev_state_ids = entry.state
@@ -332,13 +337,17 @@ class StateHandler(object):
defer.returnValue(context)
@defer.inlineCallbacks
- def resolve_state_groups_for_events(self, room_id, event_ids):
+ def resolve_state_groups_for_events(self, room_id, event_ids,
+ explicit_room_version=None):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
- room_id (str):
- event_ids (list[str]):
+ room_id (str)
+ event_ids (list[str])
+ explicit_room_version (str|None): If set uses the the given room
+ version to choose the resolution algorithm. If None, then
+ checks the database for room version.
Returns:
Deferred[_StateCacheEntry]: resolved state
@@ -364,8 +373,13 @@ class StateHandler(object):
delta_ids=delta_ids,
))
+ room_version = explicit_room_version
+ if not room_version:
+ room_version = yield self.store.get_room_version(room_id)
+
result = yield self._state_resolution_handler.resolve_state_groups(
- room_id, state_groups_ids, None, self._state_map_factory,
+ room_id, room_version, state_groups_ids, None,
+ self._state_map_factory,
)
defer.returnValue(result)
@@ -374,7 +388,8 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False,
)
- def resolve_events(self, state_sets, event):
+ @defer.inlineCallbacks
+ def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -389,14 +404,18 @@ class StateHandler(object):
for ev in st
}
+ room_version = yield self.store.get_room_version(event.room_id)
+
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events_with_state_map(state_set_ids, state_map)
+ new_state = resolve_events_with_state_map(
+ room_version, state_set_ids, state_map,
+ )
new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
- return new_state
+ defer.returnValue(new_state)
class StateResolutionHandler(object):
@@ -429,7 +448,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
- self, room_id, state_groups_ids, event_map, state_map_factory,
+ self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
):
"""Resolves conflicts between a set of state groups
@@ -438,6 +457,7 @@ class StateResolutionHandler(object):
Args:
room_id (str): room we are resolving for (used for logging)
+ room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
@@ -491,6 +511,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
+ room_version,
list(itervalues(state_groups_ids)),
event_map=event_map,
state_map_factory=state_map_factory,
@@ -572,3 +593,64 @@ def _make_state_cache_entry(
prev_group=prev_group,
delta_ids=delta_ids,
)
+
+
+def resolve_events_with_state_map(room_version, state_sets, state_map):
+ """
+ Args:
+ room_version(str): Version of the room
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
+
+ Returns
+ dict[(str, str), str]:
+ a map from (type, state_key) to event_id.
+ """
+ if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
+ return v1.resolve_events_with_state_map(
+ state_sets, state_map,
+ )
+ else:
+ # This should only happen if we added a version but forgot to add it to
+ # the list above.
+ raise Exception(
+ "No state resolution algorithm defined for version %r" % (room_version,)
+ )
+
+
+def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
+ """
+ Args:
+ room_version(str): Version of the room
+
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point fof finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
+ state_map_factory(func): will be called
+ with a list of event_ids that are needed, and should return with
+ a Deferred of dict of event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
+ """
+ if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
+ return v1.resolve_events_with_factory(
+ state_sets, event_map, state_map_factory,
+ )
+ else:
+ # This should only happen if we added a version but forgot to add it to
+ # the list above.
+ raise Exception(
+ "No state resolution algorithm defined for version %r" % (room_version,)
+ )
|