diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/keys.py | 50 | ||||
-rw-r--r-- | synapse/storage/state.py | 63 |
2 files changed, 92 insertions, 21 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 5bdf497b93..940a5f7e08 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from _base import SQLBaseStore, cached from twisted.internet import defer @@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) + @cached() + @defer.inlineCallbacks + def get_all_server_verify_keys(self, server_name): + rows = yield self._simple_select_list( + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + }, + retcols=["key_id", "verify_key"], + desc="get_all_server_verify_keys", + ) + + defer.returnValue({ + row["key_id"]: decode_verify_key_bytes( + row["key_id"], str(row["verify_key"]) + ) + for row in rows + }) + @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): """Retrieve the NACL verification key for a given server for the given @@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - sql = ( - "SELECT key_id, verify_key FROM server_signature_keys" - " WHERE server_name = ?" - " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" - ) - - rows = yield self._execute_and_decode( - "get_server_verify_keys", sql, server_name, *key_ids - ) - - keys = [] - for row in rows: - key_id = row["key_id"] - key_bytes = row["verify_key"] - key = decode_verify_key_bytes(key_id, str(key_bytes)) - keys.append(key) - defer.returnValue(keys) + keys = yield self.get_all_server_verify_keys(server_name) + defer.returnValue({ + k: keys[k] + for k in key_ids + if k in keys and keys[k] + }) + @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. @@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. """ - return self._simple_upsert( + yield self._simple_upsert( table="server_signature_keys", keyvalues={ "server_name": server_name, @@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) + self.get_all_server_verify_keys.invalidate(server_name) + def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): """Stores the JSON bytes for a set of keys from a server @@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore): "ts_valid_until_ms": ts_expires_ms, "key_json": buffer(key_json_bytes), }, + desc="store_server_keys_json", ) def get_server_keys_json(self, server_keys): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f2b17f29ea..d7844edee3 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -92,11 +92,11 @@ class StateStore(SQLBaseStore): defer.returnValue(dict(state_list)) @cached(num_args=1) - def _fetch_events_for_group(self, state_group, events): + def _fetch_events_for_group(self, key, events): return self._get_events( events, get_prev_content=False ).addCallback( - lambda evs: (state_group, evs) + lambda evs: (key, evs) ) def _store_state_groups_txn(self, txn, event, context): @@ -194,6 +194,65 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids): + def f(txn): + groups = set() + event_to_group = {} + for event_id in event_ids: + # TODO: Remove this loop. + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + event_to_group[event_id] = group + groups.add(group) + + group_to_state_ids = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + + group_to_state_ids[group] = state_ids + + return event_to_group, group_to_state_ids + + res = yield self.runInteraction( + "annotate_events_with_state_groups", + f, + ) + + event_to_group, group_to_state_ids = res + + state_list = yield defer.gatherResults( + [ + self._fetch_events_for_group(group, vals) + for group, vals in group_to_state_ids.items() + ], + consumeErrors=True, + ) + + state_dict = { + group: { + (ev.type, ev.state_key): ev + for ev in state + } + for group, state in state_list + } + + defer.returnValue([ + state_dict.get(event_to_group.get(event, None), None) + for event in event_ids + ]) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) |