From b5770f89478509b5c4c1a610753989fbe29e35e7 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 6 Jul 2015 18:46:47 +0100 Subject: Add store for client end to end keys --- synapse/storage/end_to_end_keys.py | 133 +++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 synapse/storage/end_to_end_keys.py (limited to 'synapse/storage/end_to_end_keys.py') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py new file mode 100644 index 0000000000..936a64669c --- /dev/null +++ b/synapse/storage/end_to_end_keys.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class EndToEndKeyStore(SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): + return self._simple_upsert( + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": json_bytes, + } + ) + + def get_e2e_device_keys(self, query_list): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + Returns: + Dict mapping from user-id to dict mapping from device_id to + key json byte strings. + """ + def _get_e2e_device_keys(txn): + result = {} + for user_id, device_id in query_list: + user_result = result.setdefault(user_id, {}) + keyvalues = {"user_id": user_id} + if device_id: + keyvalues["device_id"] = device_id + rows = self._simple_select_list_txn( + txn, table="e2e_device_keys_json", + keyvalues=keyvalues, + retcols=["device_id", "key_json"] + ) + for row in rows: + user_result[row["device_id"]] = row["key_json"] + return result + return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + + def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until, + key_list): + def _add_e2e_one_time_keys(txn): + for (algorithm, key_id, json_bytes) in key_list: + self._simple_upsert_txn( + txn, table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + values={ + "ts_added_ms": time_now, + "valid_until_ms": valid_until, + "key_json": json_bytes, + } + ) + return self.runInteraction( + "add_e2e_one_time_keys", _add_e2e_one_time_keys + ) + + def count_e2e_one_time_keys(self, user_id, device_id, time_now): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + def _count_e2e_one_time_keys(txn): + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" + ) + txn.execute(sql, (user_id, device_id, time_now)) + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" + " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn.fetchall(): + result[algorithm] = key_count + return result + return self.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) + + def take_e2e_one_time_keys(self, query_list, time_now): + """Take a list of one time keys out of the database""" + def _take_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND valid_until_ms > ?" + " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm, time_now)) + for key_id, key_json in txn.fetchall(): + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + return result + return self.runInteraction( + "take_e2e_one_time_keys", _take_e2e_one_time_keys + ) -- cgit 1.5.1 From 8fb79eeea4b1c388771785024b79e84b4206fc24 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 8 Jul 2015 17:04:29 +0100 Subject: Only remove one time keys when new one time keys are added --- synapse/storage/end_to_end_keys.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'synapse/storage/end_to_end_keys.py') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 936a64669c..b3cede37e3 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -58,6 +58,11 @@ class EndToEndKeyStore(SQLBaseStore): def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until, key_list): def _add_e2e_one_time_keys(txn): + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" + ) + txn.execute(sql, (user_id, device_id, time_now)) for (algorithm, key_id, json_bytes) in key_list: self._simple_upsert_txn( txn, table="e2e_one_time_keys_json", @@ -83,17 +88,12 @@ class EndToEndKeyStore(SQLBaseStore): Dict mapping from algorithm to number of keys for that algorithm. """ def _count_e2e_one_time_keys(txn): - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" - ) - txn.execute(sql, (user_id, device_id, time_now)) sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ?" + " WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?" " GROUP BY algorithm" ) - txn.execute(sql, (user_id, device_id)) + txn.execute(sql, (user_id, device_id, time_now)) result = {} for algorithm, key_count in txn.fetchall(): result[algorithm] = key_count -- cgit 1.5.1 From bf0d59ed30b63c6a355e7b3f2a74a26181fd6893 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 9 Jul 2015 14:04:03 +0100 Subject: Don't bother with a timeout for one time keys on the server. --- synapse/rest/client/v2_alpha/keys.py | 25 ++++++---------------- synapse/storage/end_to_end_keys.py | 20 ++++++----------- .../storage/schema/delta/21/end_to_end_keys.sql | 1 - 3 files changed, 13 insertions(+), 33 deletions(-) (limited to 'synapse/storage/end_to_end_keys.py') diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 3bb4ad64f3..4b617c2519 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -50,7 +50,6 @@ class KeyUploadServlet(RestServlet): "one_time_keys": { ":": "" }, - "one_time_keys_valid_for": , } """ PATTERN = client_v2_pattern("/keys/upload/(?P[^/]*)") @@ -87,13 +86,10 @@ class KeyUploadServlet(RestServlet): ) one_time_keys = body.get("one_time_keys", None) - one_time_keys_valid_for = body.get("one_time_keys_valid_for", None) if one_time_keys: - valid_until = int(one_time_keys_valid_for) + time_now logger.info( - "Adding %d one_time_keys for device %r for user %r at %d" - " valid_until %d", - len(one_time_keys), device_id, user_id, time_now, valid_until + "Adding %d one_time_keys for device %r for user %r at %d", + len(one_time_keys), device_id, user_id, time_now ) key_list = [] for key_id, key_json in one_time_keys.items(): @@ -103,23 +99,18 @@ class KeyUploadServlet(RestServlet): )) yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, valid_until, key_list + user_id, device_id, time_now, key_list ) - result = yield self.store.count_e2e_one_time_keys( - user_id, device_id, time_now - ) + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) @defer.inlineCallbacks def on_GET(self, request, device_id): auth_user, client_info = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() - time_now = self.clock.time_msec() - result = yield self.store.count_e2e_one_time_keys( - user_id, device_id, time_now - ) + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) @@ -249,9 +240,8 @@ class OneTimeKeyServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - time_now = self.clock.time_msec() results = yield self.store.take_e2e_one_time_keys( - [(user_id, device_id, algorithm)], time_now + [(user_id, device_id, algorithm)] ) defer.returnValue(self.json_result(request, results)) @@ -266,8 +256,7 @@ class OneTimeKeyServlet(RestServlet): for user_id, device_keys in body.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) - time_now = self.clock.time_msec() - results = yield self.store.take_e2e_one_time_keys(query, time_now) + results = yield self.store.take_e2e_one_time_keys(query) defer.returnValue(self.json_result(request, results)) def json_result(self, request, results): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index b3cede37e3..99dc864e46 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -55,14 +55,8 @@ class EndToEndKeyStore(SQLBaseStore): return result return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) - def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until, - key_list): + def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def _add_e2e_one_time_keys(txn): - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" - ) - txn.execute(sql, (user_id, device_id, time_now)) for (algorithm, key_id, json_bytes) in key_list: self._simple_upsert_txn( txn, table="e2e_one_time_keys_json", @@ -74,7 +68,6 @@ class EndToEndKeyStore(SQLBaseStore): }, values={ "ts_added_ms": time_now, - "valid_until_ms": valid_until, "key_json": json_bytes, } ) @@ -82,7 +75,7 @@ class EndToEndKeyStore(SQLBaseStore): "add_e2e_one_time_keys", _add_e2e_one_time_keys ) - def count_e2e_one_time_keys(self, user_id, device_id, time_now): + def count_e2e_one_time_keys(self, user_id, device_id): """ Count the number of one time keys the server has for a device Returns: Dict mapping from algorithm to number of keys for that algorithm. @@ -90,10 +83,10 @@ class EndToEndKeyStore(SQLBaseStore): def _count_e2e_one_time_keys(txn): sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?" + " WHERE user_id = ? AND device_id = ?" " GROUP BY algorithm" ) - txn.execute(sql, (user_id, device_id, time_now)) + txn.execute(sql, (user_id, device_id)) result = {} for algorithm, key_count in txn.fetchall(): result[algorithm] = key_count @@ -102,13 +95,12 @@ class EndToEndKeyStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - def take_e2e_one_time_keys(self, query_list, time_now): + def take_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" def _take_e2e_one_time_keys(txn): sql = ( "SELECT key_id, key_json FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND valid_until_ms > ?" " LIMIT 1" ) result = {} @@ -116,7 +108,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm, time_now)) + txn.execute(sql, (user_id, device_id, algorithm)) for key_id, key_json in txn.fetchall(): device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql index 107d2e67c2..95e27eb7ea 100644 --- a/synapse/storage/schema/delta/21/end_to_end_keys.sql +++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql @@ -29,7 +29,6 @@ CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. - valid_until_ms BIGINT NOT NULL, -- When this key is valid until. key_json TEXT NOT NULL, -- The key as a JSON blob. CONSTRAINT uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); -- cgit 1.5.1 From 3b5823c74d5bffc68068284145cc78a33476ac84 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 14 Jul 2015 13:08:33 +0100 Subject: s/take/claim/ for end to end key APIs --- synapse/rest/client/v2_alpha/keys.py | 10 +++++----- synapse/storage/end_to_end_keys.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'synapse/storage/end_to_end_keys.py') diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f031267751..9a0c842283 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -207,9 +207,9 @@ class KeyQueryServlet(RestServlet): class OneTimeKeyServlet(RestServlet): """ - GET /keys/take/// HTTP/1.1 + GET /keys/claim/// HTTP/1.1 - POST /keys/take HTTP/1.1 + POST /keys/claim HTTP/1.1 { "one_time_keys": { "": { @@ -226,7 +226,7 @@ class OneTimeKeyServlet(RestServlet): """ PATTERN = client_v2_pattern( - "/keys/take(?:/?|(?:/" + "/keys/claim(?:/?|(?:/" "(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" ")?)" ) @@ -240,7 +240,7 @@ class OneTimeKeyServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - results = yield self.store.take_e2e_one_time_keys( + results = yield self.store.claim_e2e_one_time_keys( [(user_id, device_id, algorithm)] ) defer.returnValue(self.json_result(request, results)) @@ -256,7 +256,7 @@ class OneTimeKeyServlet(RestServlet): for user_id, device_keys in body.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) - results = yield self.store.take_e2e_one_time_keys(query) + results = yield self.store.claim_e2e_one_time_keys(query) defer.returnValue(self.json_result(request, results)) def json_result(self, request, results): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 99dc864e46..325740d7d0 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -95,9 +95,9 @@ class EndToEndKeyStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - def take_e2e_one_time_keys(self, query_list): + def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" - def _take_e2e_one_time_keys(txn): + def _claim_e2e_one_time_keys(txn): sql = ( "SELECT key_id, key_json FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -121,5 +121,5 @@ class EndToEndKeyStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, algorithm, key_id)) return result return self.runInteraction( - "take_e2e_one_time_keys", _take_e2e_one_time_keys + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) -- cgit 1.5.1 From cf7a40b08a381ee5715f915effd63dfe241a8d61 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 21 Jul 2015 16:08:00 -0700 Subject: I think this was what was intended... --- synapse/storage/end_to_end_keys.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage/end_to_end_keys.py') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 325740d7d0..69287c43df 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -52,6 +52,7 @@ class EndToEndKeyStore(SQLBaseStore): ) for row in rows: user_result[row["device_id"]] = row["key_json"] + result[user_id] = user_result return result return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) -- cgit 1.5.1 From 20c0324e9c0ebe569d30bd6541bc8b5c9d3c7ae2 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 21 Jul 2015 16:21:37 -0700 Subject: Dodesn't seem to make any difference: guess it does work with the object reference --- synapse/storage/end_to_end_keys.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage/end_to_end_keys.py') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 69287c43df..325740d7d0 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -52,7 +52,6 @@ class EndToEndKeyStore(SQLBaseStore): ) for row in rows: user_result[row["device_id"]] = row["key_json"] - result[user_id] = user_result return result return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) -- cgit 1.5.1