summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py5
-rw-r--r--synapse/storage/events.py6
-rw-r--r--synapse/storage/state.py25
3 files changed, 35 insertions, 1 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9e348590ba..2210b3ddfb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -124,6 +124,11 @@ class Cache(object):
         self.sequence += 1
         self.cache.pop(keyargs, None)
 
+    def invalidate_all(self):
+        self.check_thread()
+        self.sequence += 1
+        self.cache.clear()
+
 
 def cached(max_entries=1000, num_args=1, lru=False):
     """ A method decorator that applies a memoizing cache around the function.
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 38395c66ab..52074b4cc8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -107,6 +107,8 @@ class EventsStore(SQLBaseStore):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
+            txn.call_after(self.get_current_state_for_key.invalidate_all)
+
             self._simple_delete_txn(
                 txn,
                 table="current_state_events",
@@ -335,6 +337,10 @@ class EventsStore(SQLBaseStore):
             )
 
             if is_new_state and not context.rejected:
+                txn.call_after(
+                    self.get_current_state_for_key.invalidate,
+                    event.room_id, event.type, event.state_key
+                )
                 self._simple_upsert_txn(
                     txn,
                     "current_state_events",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dbc0e49c1f..6df7350552 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 
 from twisted.internet import defer
 
@@ -130,6 +130,12 @@ class StateStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
+        if event_type and state_key is not None:
+            result = yield self.get_current_state_for_key(
+                room_id, event_type, state_key
+            )
+            defer.returnValue(result)
+
         def f(txn):
             sql = (
                 "SELECT event_id FROM current_state_events"
@@ -153,6 +159,23 @@ class StateStore(SQLBaseStore):
         events = yield self.runInteraction("get_current_state", f)
         defer.returnValue(events)
 
+    @cached(num_args=3)
+    @defer.inlineCallbacks
+    def get_current_state_for_key(self, room_id, event_type, state_key):
+        def f(txn):
+            sql = (
+                "SELECT event_id FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?"
+            )
+
+            args = (room_id, event_type, state_key)
+            txn.execute(sql, args)
+            results = txn.fetchall()
+            return [r[0] for r in results]
+        event_ids = yield self.runInteraction("get_current_state_for_key", f)
+        events = yield self._get_events(event_ids, get_prev_content=False)
+        defer.returnValue(events)
+
 
 def _make_group_id(clock):
     return str(int(clock.time_msec())) + random_string(5)