diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..8c97f2af5c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+ async def set_e2e_fallback_keys(
+ self, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
+ # fallback_keys will usually only have one item in it, so using a for
+ # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+ # FIXME: make sure that only one key per algorithm is uploaded
+ for key_id, fallback_key in fallback_keys.items():
+ algorithm, key_id = key_id.split(":", 1)
+ await self.db_pool.simple_upsert(
+ "e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ desc="set_e2e_fallback_key",
+ )
+
+ @cached(max_entries=10000)
+ async def get_e2e_unused_fallback_key_types(
+ self, user_id: str, device_id: str
+ ) -> List[str]:
+ """Returns the fallback key types that have an unused key.
+
+ Args:
+ user_id: the user whose keys are being queried
+ device_id: the device whose keys are being queried
+
+ Returns:
+ a list of key types
+ """
+ return await self.db_pool.simple_select_onecol(
+ "e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+ retcol="algorithm",
+ desc="get_e2e_unused_fallback_key_types",
+ )
+
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
@@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
+ fallback_sql = (
+ "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
result = {}
delete = []
+ used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
+ otk_row = txn.fetchone()
+ if otk_row is not None:
+ key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
+ else:
+ # no one-time key available, so see if there's a fallback
+ # key
+ txn.execute(fallback_sql, (user_id, device_id, algorithm))
+ fallback_row = txn.fetchone()
+ if fallback_row is not None:
+ key_id, key_json, used = fallback_row
+ device_result[algorithm + ":" + key_id] = key_json
+ if not used:
+ used_fallbacks.append(
+ (user_id, device_id, algorithm, key_id)
+ )
+
+ # drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ # mark fallback keys as used
+ for user_id, device_id, algorithm, key_id in used_fallbacks:
+ self.db_pool.simple_update_txn(
+ txn,
+ "e2e_fallback_keys_json",
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ {"used": True},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
return result
return await self.db_pool.runInteraction(
@@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+ user_id TEXT NOT NULL, -- The user this fallback key is for.
+ device_id TEXT NOT NULL, -- The device this fallback key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+ CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
|