summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py107
1 files changed, 41 insertions, 66 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4b136b3054..1b454a56a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -107,8 +107,9 @@ class StateHandler(object):
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key="",
-                          latest_event_ids=None):
+    def get_current_state(
+        self, room_id, event_type=None, state_key="", latest_event_ids=None
+    ):
         """ Retrieves the current state for the room. This is done by
         calling `get_latest_events_in_room` to get the leading edges of the
         event graph and then resolving any of the state conflicts.
@@ -137,8 +138,9 @@ class StateHandler(object):
             defer.returnValue(event)
             return
 
-        state_map = yield self.store.get_events(list(state.values()),
-                                                get_prev_content=False)
+        state_map = yield self.store.get_events(
+            list(state.values()), get_prev_content=False
+        )
         state = {
             key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
         }
@@ -220,9 +222,7 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                prev_state_ids = {
-                    (s.type, s.state_key): s.event_id for s in old_state
-                }
+                prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
                 if event.is_state():
                     current_state_ids = dict(prev_state_ids)
                     key = (event.type, event.state_key)
@@ -248,9 +248,7 @@ class StateHandler(object):
             # Let's just correctly fill out the context and create a
             # new state group for it.
 
-            prev_state_ids = {
-                (s.type, s.state_key): s.event_id for s in old_state
-            }
+            prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -282,7 +280,7 @@ class StateHandler(object):
         logger.debug("calling resolve_state_groups from compute_event_context")
 
         entry = yield self.resolve_state_groups_for_events(
-            event.room_id, event.prev_event_ids(),
+            event.room_id, event.prev_event_ids()
         )
 
         prev_state_ids = entry.state
@@ -305,9 +303,7 @@ class StateHandler(object):
                 # If the state at the event has a state group assigned then
                 # we can use that as the prev group
                 prev_group = entry.state_group
-                delta_ids = {
-                    key: event.event_id
-                }
+                delta_ids = {key: event.event_id}
             elif entry.prev_group:
                 # If the state at the event only has a prev group, then we can
                 # use that as a prev group too.
@@ -369,31 +365,31 @@ class StateHandler(object):
         # map from state group id to the state in that state group (where
         # 'state' is a map from state key to event id)
         # dict[int, dict[(str, str), str]]
-        state_groups_ids = yield self.store.get_state_groups_ids(
-            room_id, event_ids
-        )
+        state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
 
         if len(state_groups_ids) == 0:
-            defer.returnValue(_StateCacheEntry(
-                state={},
-                state_group=None,
-            ))
+            defer.returnValue(_StateCacheEntry(state={}, state_group=None))
         elif len(state_groups_ids) == 1:
             name, state_list = list(state_groups_ids.items()).pop()
 
             prev_group, delta_ids = yield self.store.get_state_group_delta(name)
 
-            defer.returnValue(_StateCacheEntry(
-                state=state_list,
-                state_group=name,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-            ))
+            defer.returnValue(
+                _StateCacheEntry(
+                    state=state_list,
+                    state_group=name,
+                    prev_group=prev_group,
+                    delta_ids=delta_ids,
+                )
+            )
 
         room_version = yield self.store.get_room_version(room_id)
 
         result = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, room_version, state_groups_ids, None,
+            room_id,
+            room_version,
+            state_groups_ids,
+            None,
             state_res_store=StateResolutionStore(self.store),
         )
         defer.returnValue(result)
@@ -403,27 +399,21 @@ class StateHandler(object):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
-        state_set_ids = [{
-            (ev.type, ev.state_key): ev.event_id
-            for ev in st
-        } for st in state_sets]
-
-        state_map = {
-            ev.event_id: ev
-            for st in state_sets
-            for ev in st
-        }
+        state_set_ids = [
+            {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
+        ]
+
+        state_map = {ev.event_id: ev for st in state_sets for ev in st}
 
         with Measure(self.clock, "state._resolve_events"):
             new_state = yield resolve_events_with_store(
-                room_version, state_set_ids,
+                room_version,
+                state_set_ids,
                 event_map=state_map,
                 state_res_store=StateResolutionStore(self.store),
             )
 
-        new_state = {
-            key: state_map[ev_id] for key, ev_id in iteritems(new_state)
-        }
+        new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
 
         defer.returnValue(new_state)
 
@@ -434,6 +424,7 @@ class StateResolutionHandler(object):
     Note that the storage layer depends on this handler, so all functions must
     be storage-independent.
     """
+
     def __init__(self, hs):
         self.clock = hs.get_clock()
 
@@ -453,7 +444,7 @@ class StateResolutionHandler(object):
     @defer.inlineCallbacks
     @log_function
     def resolve_state_groups(
-        self, room_id, room_version, state_groups_ids, event_map, state_res_store,
+        self, room_id, room_version, state_groups_ids, event_map, state_res_store
     ):
         """Resolves conflicts between a set of state groups
 
@@ -480,10 +471,7 @@ class StateResolutionHandler(object):
         Returns:
             Deferred[_StateCacheEntry]: resolved state
         """
-        logger.debug(
-            "resolve_state_groups state_groups %s",
-            state_groups_ids.keys()
-        )
+        logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
 
         group_names = frozenset(state_groups_ids.keys())
 
@@ -540,10 +528,7 @@ class StateResolutionHandler(object):
             defer.returnValue(cache)
 
 
-def _make_state_cache_entry(
-    new_state,
-    state_groups_ids,
-):
+def _make_state_cache_entry(new_state, state_groups_ids):
     """Given a resolved state, and a set of input state groups, pick one to base
     a new state group on (if any), and return an appropriately-constructed
     _StateCacheEntry.
@@ -573,10 +558,7 @@ def _make_state_cache_entry(
         old_state_event_ids = set(itervalues(state))
         if new_state_event_ids == old_state_event_ids:
             # got an exact match.
-            return _StateCacheEntry(
-                state=new_state,
-                state_group=sg,
-            )
+            return _StateCacheEntry(state=new_state, state_group=sg)
 
     # TODO: We want to create a state group for this set of events, to
     # increase cache hits, but we need to make sure that it doesn't
@@ -587,20 +569,13 @@ def _make_state_cache_entry(
     delta_ids = None
 
     for old_group, old_state in iteritems(state_groups_ids):
-        n_delta_ids = {
-            k: v
-            for k, v in iteritems(new_state)
-            if old_state.get(k) != v
-        }
+        n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
         if not delta_ids or len(n_delta_ids) < len(delta_ids):
             prev_group = old_group
             delta_ids = n_delta_ids
 
     return _StateCacheEntry(
-        state=new_state,
-        state_group=None,
-        prev_group=prev_group,
-        delta_ids=delta_ids,
+        state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
     )
 
 
@@ -629,11 +604,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     v = KNOWN_ROOM_VERSIONS[room_version]
     if v.state_res == StateResolutionVersions.V1:
         return v1.resolve_events_with_store(
-            state_sets, event_map, state_res_store.get_events,
+            state_sets, event_map, state_res_store.get_events
         )
     else:
         return v2.resolve_events_with_store(
-            room_version, state_sets, event_map, state_res_store,
+            room_version, state_sets, event_map, state_res_store
         )