diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 53 |
1 files changed, 45 insertions, 8 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dbc0e49c1f..b24de34f23 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from ._base import SQLBaseStore, cached from twisted.internet import defer @@ -43,6 +43,7 @@ class StateStore(SQLBaseStore): * `state_groups_state`: Maps state group to state events. """ + @defer.inlineCallbacks def get_state_groups(self, event_ids): """ Get the state groups for the given list of event_ids @@ -71,17 +72,29 @@ class StateStore(SQLBaseStore): retcol="event_id", ) - state = self._get_events_txn(txn, state_ids) - - res[group] = state + res[group] = state_ids return res - return self.runInteraction( + states = yield self.runInteraction( "get_state_groups", f, ) + @defer.inlineCallbacks + def c(vals): + vals[:] = yield self._get_events(vals, get_prev_content=False) + + yield defer.gatherResults( + [ + c(vals) + for vals in states.values() + ], + consumeErrors=True, + ) + + defer.returnValue(states) + def _store_state_groups_txn(self, txn, event, context): if context.current_state is None: return @@ -130,6 +143,12 @@ class StateStore(SQLBaseStore): @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): + if event_type and state_key is not None: + result = yield self.get_current_state_for_key( + room_id, event_type, state_key + ) + defer.returnValue(result) + def f(txn): sql = ( "SELECT event_id FROM current_state_events" @@ -146,11 +165,29 @@ class StateStore(SQLBaseStore): args = (room_id, ) txn.execute(sql, args) - results = self.cursor_to_dict(txn) + results = txn.fetchall() - return self._parse_events_txn(txn, results) + return [r[0] for r in results] - events = yield self.runInteraction("get_current_state", f) + event_ids = yield self.runInteraction("get_current_state", f) + events = yield self._get_events(event_ids, get_prev_content=False) + defer.returnValue(events) + + @cached(num_args=3) + @defer.inlineCallbacks + def get_current_state_for_key(self, room_id, event_type, state_key): + def f(txn): + sql = ( + "SELECT event_id FROM current_state_events" + " WHERE room_id = ? AND type = ? AND state_key = ?" + ) + + args = (room_id, event_type, state_key) + txn.execute(sql, args) + results = txn.fetchall() + return [r[0] for r in results] + event_ids = yield self.runInteraction("get_current_state_for_key", f) + events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) |