diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 0d24aa7ac2..bfe6e61602 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -289,6 +289,7 @@ class Keyring(object):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
old_verify_keys = {}
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 69bc15ba75..e434847b45 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,6 +13,7 @@
# limitations under the License.
from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.http.servlet import parse_integer
from synapse.api.errors import SynapseError, Codes
from twisted.web.resource import Resource
@@ -44,7 +45,13 @@ class RemoteKey(Resource):
POST /_matrix/v2/query HTTP/1.1
Content-Type: application/json
{
- "server_keys": { "remote.server.example.com": ["a.key.id"] }
+ "server_keys": {
+ "remote.server.example.com": {
+ "a.key.id": {
+ "minimum_valid_until_ts": 1234567890123
+ }
+ }
+ }
}
Response:
@@ -96,10 +103,16 @@ class RemoteKey(Resource):
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
- query = {server: [None]}
+ query = {server: {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
- query = {server: [key_id]}
+ minimum_valid_until_ts = parse_integer(
+ request, "minimum_valid_until_ts"
+ )
+ arguments = {}
+ if minimum_valid_until_ts is not None:
+ arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
+ query = {server: {key_id: arguments}}
else:
raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
@@ -128,8 +141,11 @@ class RemoteKey(Resource):
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
+ if not key_ids:
+ key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
@@ -152,9 +168,44 @@ class RemoteKey(Resource):
if key_id is not None:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
- if (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
+ req_key = query.get(server_name, {}).get(key_id, {})
+ req_valid_until = req_key.get("minimum_valid_until_ts")
+ miss = False
+ if req_valid_until is not None:
+ if ts_valid_until_ms < req_valid_until:
+ logger.debug(
+ "Cached response for %r/%r is older than requested"
+ ": valid_until (%r) < minimum_valid_until (%r)",
+ server_name, key_id,
+ ts_valid_until_ms, req_valid_until
+ )
+ miss = True
+ else:
+ logger.debug(
+ "Cached response for %r/%r is newer than requested"
+ ": valid_until (%r) >= minimum_valid_until (%r)",
+ server_name, key_id,
+ ts_valid_until_ms, req_valid_until
+ )
+ elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
+ logger.debug(
+ "Cached response for %r/%r is too old"
+ ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
+ server_name, key_id,
+ ts_added_ms, ts_valid_until_ms, time_now_ms
+ )
# We more than half way through the lifetime of the
# response. We should fetch a fresh copy.
+ miss = True
+ else:
+ logger.debug(
+ "Cached response for %r/%r is still valid"
+ ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
+ server_name, key_id,
+ ts_added_ms, ts_valid_until_ms, time_now_ms
+ )
+
+ if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
json_results.add(bytes(most_recent_result["key_json"]))
else:
|