diff --git a/synapse/state.py b/synapse/state.py
index 8a556a27f6..cbb4243fad 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -136,6 +136,39 @@ class StateHandler(object):
defer.returnValue(res[1].values())
@defer.inlineCallbacks
+ def annotate_context_with_state(self, event, context):
+ if event.is_state():
+ ret = yield self.resolve_state_groups(
+ [e for e, _ in event.prev_events],
+ event_type=event.event_type,
+ state_key=event.state_key,
+ )
+ else:
+ ret = yield self.resolve_state_groups(
+ [e for e, _ in event.prev_events],
+ )
+
+ group, curr_state, prev_state = ret
+
+ context.current_state = curr_state
+
+ prev_state = yield self.store.add_event_hashes(
+ prev_state
+ )
+
+ if hasattr(event, "auth_events") and event.auth_events:
+ auth_ids = zip(*event.auth_events)[0]
+ context.auth_events = {
+ k: v
+ for k, v in context.current_state.items()
+ if v.event_id in auth_ids
+ }
+
+ defer.returnValue(
+ (group, prev_state)
+ )
+
+ @defer.inlineCallbacks
@log_function
def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
""" Given a list of event_ids this method fetches the state at each
|