summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/keys.py50
-rw-r--r--synapse/storage/state.py63
2 files changed, 92 insertions, 21 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..940a5f7e08 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cached
 
 from twisted.internet import defer
 
@@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )
 
+    @cached()
+    @defer.inlineCallbacks
+    def get_all_server_verify_keys(self, server_name):
+        rows = yield self._simple_select_list(
+            table="server_signature_keys",
+            keyvalues={
+                "server_name": server_name,
+            },
+            retcols=["key_id", "verify_key"],
+            desc="get_all_server_verify_keys",
+        )
+
+        defer.returnValue({
+            row["key_id"]: decode_verify_key_bytes(
+                row["key_id"], str(row["verify_key"])
+            )
+            for row in rows
+        })
+
     @defer.inlineCallbacks
     def get_server_verify_keys(self, server_name, key_ids):
         """Retrieve the NACL verification key for a given server for the given
@@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
         Returns:
             (list of VerifyKey): The verification keys.
         """
-        sql = (
-            "SELECT key_id, verify_key FROM server_signature_keys"
-            " WHERE server_name = ?"
-            " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
-        )
-
-        rows = yield self._execute_and_decode(
-            "get_server_verify_keys", sql, server_name, *key_ids
-        )
-
-        keys = []
-        for row in rows:
-            key_id = row["key_id"]
-            key_bytes = row["verify_key"]
-            key = decode_verify_key_bytes(key_id, str(key_bytes))
-            keys.append(key)
-        defer.returnValue(keys)
+        keys = yield self.get_all_server_verify_keys(server_name)
+        defer.returnValue({
+            k: keys[k]
+            for k in key_ids
+            if k in keys and keys[k]
+        })
 
+    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
@@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
             ts_now_ms (int): The time now in milliseconds
             verification_key (VerifyKey): The NACL verify key.
         """
-        return self._simple_upsert(
+        yield self._simple_upsert(
             table="server_signature_keys",
             keyvalues={
                 "server_name": server_name,
@@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
+        self.get_all_server_verify_keys.invalidate(server_name)
+
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
         """Stores the JSON bytes for a set of keys from a server
@@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
                 "ts_valid_until_ms": ts_expires_ms,
                 "key_json": buffer(key_json_bytes),
             },
+            desc="store_server_keys_json",
         )
 
     def get_server_keys_json(self, server_keys):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f2b17f29ea..d7844edee3 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
         defer.returnValue(dict(state_list))
 
     @cached(num_args=1)
-    def _fetch_events_for_group(self, state_group, events):
+    def _fetch_events_for_group(self, key, events):
         return self._get_events(
             events, get_prev_content=False
         ).addCallback(
-            lambda evs: (state_group, evs)
+            lambda evs: (key, evs)
         )
 
     def _store_state_groups_txn(self, txn, event, context):
@@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
+    @defer.inlineCallbacks
+    def get_state_for_events(self, room_id, event_ids):
+        def f(txn):
+            groups = set()
+            event_to_group = {}
+            for event_id in event_ids:
+                # TODO: Remove this loop.
+                group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="event_to_state_groups",
+                    keyvalues={"event_id": event_id},
+                    retcol="state_group",
+                    allow_none=True,
+                )
+                if group:
+                    event_to_group[event_id] = group
+                    groups.add(group)
+
+            group_to_state_ids = {}
+            for group in groups:
+                state_ids = self._simple_select_onecol_txn(
+                    txn,
+                    table="state_groups_state",
+                    keyvalues={"state_group": group},
+                    retcol="event_id",
+                )
+
+                group_to_state_ids[group] = state_ids
+
+            return event_to_group, group_to_state_ids
+
+        res = yield self.runInteraction(
+            "annotate_events_with_state_groups",
+            f,
+        )
+
+        event_to_group, group_to_state_ids = res
+
+        state_list = yield defer.gatherResults(
+            [
+                self._fetch_events_for_group(group, vals)
+                for group, vals in group_to_state_ids.items()
+            ],
+            consumeErrors=True,
+        )
+
+        state_dict = {
+            group: {
+                (ev.type, ev.state_key): ev
+                for ev in state
+            }
+            for group, state in state_list
+        }
+
+        defer.returnValue([
+            state_dict.get(event_to_group.get(event, None), None)
+            for event in event_ids
+        ])
+
 
 def _make_group_id(clock):
     return str(int(clock.time_msec())) + random_string(5)