diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d81887f5b7..7ac93aef33 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -62,18 +62,7 @@ class E2eKeysHandler(object):
...
}
}
- },
- "attestations": [
- "user_id": "<user_id>",
- "device_id": "<device_id>",
- "keys": {
- "ed25519": "<key_base64>"
- },
- "state": "<verified or revoked>",
- "signatures": {
- "<algorithm>:<device_id>": "<signature_base64>"
- }
- ]
+ }
}
"""
device_keys_query = query_body.get("device_keys", {})
@@ -93,13 +82,11 @@ class E2eKeysHandler(object):
# First get local devices.
failures = {}
results = {}
- attestations = []
if local_query:
- local_result = yield self.query_local_devices(local_query)
+ local_result = yield self.query_local_devices(local_query, req_user_id)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
- attestations = yield self.query_local_attestations(req_user_id, local_query)
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
@@ -121,11 +108,14 @@ class E2eKeysHandler(object):
for device_id, device in iteritems(devices):
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
+ attestations = device.get("attestations", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
+ if attestations:
+ unsigned["attestations"] = attestations
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
@@ -158,16 +148,16 @@ class E2eKeysHandler(object):
defer.returnValue({
"device_keys": results, "failures": failures,
- "attestations": attestations
})
@defer.inlineCallbacks
- def query_local_devices(self, query):
+ def query_local_devices(self, query, req_user_id=None):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
+ req_user_id: the user requesting the devices
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
@@ -192,7 +182,9 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = yield self.store.get_e2e_device_keys(local_query)
+ results = yield self.store.get_e2e_device_keys(
+ local_query, req_user_id=req_user_id,
+ )
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
@@ -201,36 +193,16 @@ class E2eKeysHandler(object):
r = dict(device_info["keys"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
+ attestations = device_info.get("attestations", None)
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
+ if attestations is not None:
+ r["unsigned"]["attestations"] = attestations
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
- def query_local_attestations(self, req_user_id, query):
- local_query = []
-
- for user_id, device_ids in query.items():
- # we use UserID.from_string to catch invalid user ids
- if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
- raise SynapseError(400, "Not a user here")
-
- if not device_ids:
- local_query.append((user_id, None))
- else:
- for device_id in device_ids:
- local_query.append((user_id, device_id))
-
- results = yield self.store.get_e2e_attestations(req_user_id, local_query)
-
- # FIXME: combine signatures of the same payload
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index ff3ce84eb1..f08a4efbfc 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -66,7 +66,7 @@ class EndToEndKeyStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
- include_deleted_devices=False,
+ include_deleted_devices=False, req_user_id=None,
):
"""Fetch a list of device keys.
Args:
@@ -76,6 +76,7 @@ class EndToEndKeyStore(SQLBaseStore):
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
+ req_user_id: The user requesting the device list
Returns:
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
@@ -85,7 +86,7 @@ class EndToEndKeyStore(SQLBaseStore):
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
- query_list, include_all_devices, include_deleted_devices,
+ query_list, include_all_devices, include_deleted_devices, req_user_id,
)
for user_id, device_keys in iteritems(results):
@@ -96,7 +97,7 @@ class EndToEndKeyStore(SQLBaseStore):
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False,
- include_deleted_devices=False,
+ include_deleted_devices=False, req_user_id=None,
):
query_clauses = []
query_params = []
@@ -142,6 +143,16 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
+ attestations = self._get_e2e_attestations_txn(txn, req_user_id, query_list)
+
+ for attestation in attestations:
+ user_id = attestation["user_id"]
+ device_id = attestation["device_id"]
+ # FIXME: combine signatures of the same payload?
+ if user_id in result and device_id in result[user_id]:
+ result[user_id][device_id].setdefault("attestations", []) \
+ .append(attestation)
+
return result
@defer.inlineCallbacks
@@ -305,35 +316,35 @@ class EndToEndKeyStore(SQLBaseStore):
"add_e2e_attestations", add_e2e_attestations_txn
)
- @defer.inlineCallbacks
- def get_e2e_attestations(self, from_user_id, query_list):
- def get_e2e_attestations_txn(txn):
- query_clauses = []
- query_params = []
+ def _get_e2e_attestations_txn(self, txn, from_user_id, query_list):
+
+ query_clauses = []
+ query_params = []
- for (user_id, device_id) in query_list:
- query_clause = "(from_user_id = ? OR from_user_id = ?) AND user_id = ?"
+ for (user_id, device_id) in query_list:
+ if from_user_id:
+ query_clause = "(from_user_id = ? OR from_user_id = ?)"
query_params.append(from_user_id)
- query_params.append(user_id)
- query_params.append(user_id)
+ else:
+ query_clause = "(from_user_id = ?)"
+ query_params.append(user_id)
- if device_id is not None:
- query_clause += " AND device_id = ?"
- query_params.append(device_id)
+ query_clause += " AND user_id = ?"
+ query_params.append(user_id)
- query_clauses.append(query_clause)
+ if device_id is not None:
+ query_clause += " AND device_id = ?"
+ query_params.append(device_id)
- sql = (
- "SELECT attestation "
- " FROM e2e_attestations "
- " WHERE %s"
- ) % (
- " OR ".join("(" + q + ")" for q in query_clauses)
- )
+ query_clauses.append(query_clause)
- txn.execute(sql, query_params)
- return [json.loads(row[0]) for row in txn]
- results = yield self.runInteraction(
- "get_e2e_attestations", get_e2e_attestations_txn
+ sql = (
+ "SELECT attestation "
+ " FROM e2e_attestations "
+ " WHERE %s"
+ ) % (
+ " OR ".join("(" + q + ")" for q in query_clauses)
)
- defer.returnValue(results)
+
+ txn.execute(sql, query_params)
+ return [json.loads(row[0]) for row in txn]
|