diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dbc0e49c1f..6df7350552 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from twisted.internet import defer
@@ -130,6 +130,12 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
+ if event_type and state_key is not None:
+ result = yield self.get_current_state_for_key(
+ room_id, event_type, state_key
+ )
+ defer.returnValue(result)
+
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
@@ -153,6 +159,23 @@ class StateStore(SQLBaseStore):
events = yield self.runInteraction("get_current_state", f)
defer.returnValue(events)
+ @cached(num_args=3)
+ @defer.inlineCallbacks
+ def get_current_state_for_key(self, room_id, event_type, state_key):
+ def f(txn):
+ sql = (
+ "SELECT event_id FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?"
+ )
+
+ args = (room_id, event_type, state_key)
+ txn.execute(sql, args)
+ results = txn.fetchall()
+ return [r[0] for r in results]
+ event_ids = yield self.runInteraction("get_current_state_for_key", f)
+ events = yield self._get_events(event_ids, get_prev_content=False)
+ defer.returnValue(events)
+
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)
|