diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 2 | ||||
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 125 | ||||
-rw-r--r-- | synapse/storage/keys.py | 50 | ||||
-rw-r--r-- | synapse/storage/schema/delta/21/end_to_end_keys.sql | 34 | ||||
-rw-r--r-- | synapse/storage/state.py | 63 |
5 files changed, 253 insertions, 21 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 2bc88a7954..71d5d92500 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,6 +37,7 @@ from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore +from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore @@ -77,6 +78,7 @@ class DataStore(RoomMemberStore, RoomStore, ApplicationServiceTransactionStore, EventsStore, ReceiptsStore, + EndToEndKeyStore, ): def __init__(self, hs): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py new file mode 100644 index 0000000000..99dc864e46 --- /dev/null +++ b/synapse/storage/end_to_end_keys.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class EndToEndKeyStore(SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): + return self._simple_upsert( + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": json_bytes, + } + ) + + def get_e2e_device_keys(self, query_list): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + Returns: + Dict mapping from user-id to dict mapping from device_id to + key json byte strings. + """ + def _get_e2e_device_keys(txn): + result = {} + for user_id, device_id in query_list: + user_result = result.setdefault(user_id, {}) + keyvalues = {"user_id": user_id} + if device_id: + keyvalues["device_id"] = device_id + rows = self._simple_select_list_txn( + txn, table="e2e_device_keys_json", + keyvalues=keyvalues, + retcols=["device_id", "key_json"] + ) + for row in rows: + user_result[row["device_id"]] = row["key_json"] + return result + return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + + def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): + def _add_e2e_one_time_keys(txn): + for (algorithm, key_id, json_bytes) in key_list: + self._simple_upsert_txn( + txn, table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": json_bytes, + } + ) + return self.runInteraction( + "add_e2e_one_time_keys", _add_e2e_one_time_keys + ) + + def count_e2e_one_time_keys(self, user_id, device_id): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + def _count_e2e_one_time_keys(txn): + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" + " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn.fetchall(): + result[algorithm] = key_count + return result + return self.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) + + def take_e2e_one_time_keys(self, query_list): + """Take a list of one time keys out of the database""" + def _take_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm)) + for key_id, key_json in txn.fetchall(): + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + return result + return self.runInteraction( + "take_e2e_one_time_keys", _take_e2e_one_time_keys + ) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 5bdf497b93..940a5f7e08 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from _base import SQLBaseStore, cached from twisted.internet import defer @@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) + @cached() + @defer.inlineCallbacks + def get_all_server_verify_keys(self, server_name): + rows = yield self._simple_select_list( + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + }, + retcols=["key_id", "verify_key"], + desc="get_all_server_verify_keys", + ) + + defer.returnValue({ + row["key_id"]: decode_verify_key_bytes( + row["key_id"], str(row["verify_key"]) + ) + for row in rows + }) + @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): """Retrieve the NACL verification key for a given server for the given @@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - sql = ( - "SELECT key_id, verify_key FROM server_signature_keys" - " WHERE server_name = ?" - " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" - ) - - rows = yield self._execute_and_decode( - "get_server_verify_keys", sql, server_name, *key_ids - ) - - keys = [] - for row in rows: - key_id = row["key_id"] - key_bytes = row["verify_key"] - key = decode_verify_key_bytes(key_id, str(key_bytes)) - keys.append(key) - defer.returnValue(keys) + keys = yield self.get_all_server_verify_keys(server_name) + defer.returnValue({ + k: keys[k] + for k in key_ids + if k in keys and keys[k] + }) + @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. @@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. """ - return self._simple_upsert( + yield self._simple_upsert( table="server_signature_keys", keyvalues={ "server_name": server_name, @@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) + self.get_all_server_verify_keys.invalidate(server_name) + def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): """Stores the JSON bytes for a set of keys from a server @@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore): "ts_valid_until_ms": ts_expires_ms, "key_json": buffer(key_json_bytes), }, + desc="store_server_keys_json", ) def get_server_keys_json(self, server_keys): diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql new file mode 100644 index 0000000000..8b4a380d11 --- /dev/null +++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql @@ -0,0 +1,34 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( + user_id TEXT NOT NULL, -- The user these keys are for. + device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. + ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. + key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. + CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) +); + + +CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( + user_id TEXT NOT NULL, -- The user this one-time key is for. + device_id TEXT NOT NULL, -- The device this one-time key is for. + algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. + key_json TEXT NOT NULL, -- The key as a JSON blob. + CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f2b17f29ea..d7844edee3 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -92,11 +92,11 @@ class StateStore(SQLBaseStore): defer.returnValue(dict(state_list)) @cached(num_args=1) - def _fetch_events_for_group(self, state_group, events): + def _fetch_events_for_group(self, key, events): return self._get_events( events, get_prev_content=False ).addCallback( - lambda evs: (state_group, evs) + lambda evs: (key, evs) ) def _store_state_groups_txn(self, txn, event, context): @@ -194,6 +194,65 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids): + def f(txn): + groups = set() + event_to_group = {} + for event_id in event_ids: + # TODO: Remove this loop. + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + event_to_group[event_id] = group + groups.add(group) + + group_to_state_ids = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + + group_to_state_ids[group] = state_ids + + return event_to_group, group_to_state_ids + + res = yield self.runInteraction( + "annotate_events_with_state_groups", + f, + ) + + event_to_group, group_to_state_ids = res + + state_list = yield defer.gatherResults( + [ + self._fetch_events_for_group(group, vals) + for group, vals in group_to_state_ids.items() + ], + consumeErrors=True, + ) + + state_dict = { + group: { + (ev.type, ev.state_key): ev + for ev in state + } + for group, state in state_list + } + + defer.returnValue([ + state_dict.get(event_to_group.get(event, None), None) + for event in event_ids + ]) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) |