diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..940a5f7e08 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cached
from twisted.internet import defer
@@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
+ @cached()
+ @defer.inlineCallbacks
+ def get_all_server_verify_keys(self, server_name):
+ rows = yield self._simple_select_list(
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ },
+ retcols=["key_id", "verify_key"],
+ desc="get_all_server_verify_keys",
+ )
+
+ defer.returnValue({
+ row["key_id"]: decode_verify_key_bytes(
+ row["key_id"], str(row["verify_key"])
+ )
+ for row in rows
+ })
+
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given
@@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- sql = (
- "SELECT key_id, verify_key FROM server_signature_keys"
- " WHERE server_name = ?"
- " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
- )
-
- rows = yield self._execute_and_decode(
- "get_server_verify_keys", sql, server_name, *key_ids
- )
-
- keys = []
- for row in rows:
- key_id = row["key_id"]
- key_bytes = row["verify_key"]
- key = decode_verify_key_bytes(key_id, str(key_bytes))
- keys.append(key)
- defer.returnValue(keys)
+ keys = yield self.get_all_server_verify_keys(server_name)
+ defer.returnValue({
+ k: keys[k]
+ for k in key_ids
+ if k in keys and keys[k]
+ })
+ @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
@@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
- return self._simple_upsert(
+ yield self._simple_upsert(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
@@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
+ self.get_all_server_verify_keys.invalidate(server_name)
+
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
@@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes),
},
+ desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f2b17f29ea..d7844edee3 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
defer.returnValue(dict(state_list))
@cached(num_args=1)
- def _fetch_events_for_group(self, state_group, events):
+ def _fetch_events_for_group(self, key, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
- lambda evs: (state_group, evs)
+ lambda evs: (key, evs)
)
def _store_state_groups_txn(self, txn, event, context):
@@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
+ @defer.inlineCallbacks
+ def get_state_for_events(self, room_id, event_ids):
+ def f(txn):
+ groups = set()
+ event_to_group = {}
+ for event_id in event_ids:
+ # TODO: Remove this loop.
+ group = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ event_to_group[event_id] = group
+ groups.add(group)
+
+ group_to_state_ids = {}
+ for group in groups:
+ state_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+
+ group_to_state_ids[group] = state_ids
+
+ return event_to_group, group_to_state_ids
+
+ res = yield self.runInteraction(
+ "annotate_events_with_state_groups",
+ f,
+ )
+
+ event_to_group, group_to_state_ids = res
+
+ state_list = yield defer.gatherResults(
+ [
+ self._fetch_events_for_group(group, vals)
+ for group, vals in group_to_state_ids.items()
+ ],
+ consumeErrors=True,
+ )
+
+ state_dict = {
+ group: {
+ (ev.type, ev.state_key): ev
+ for ev in state
+ }
+ for group, state in state_list
+ }
+
+ defer.returnValue([
+ state_dict.get(event_to_group.get(event, None), None)
+ for event in event_ids
+ ])
+
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)
|