summary refs log tree commit diff
path: root/synapse/rest/key/v2
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/key/v2')
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py59
1 files changed, 55 insertions, 4 deletions
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 69bc15ba75..e434847b45 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.http.servlet import parse_integer
 from synapse.api.errors import SynapseError, Codes
 
 from twisted.web.resource import Resource
@@ -44,7 +45,13 @@ class RemoteKey(Resource):
     POST /_matrix/v2/query HTTP/1.1
     Content-Type: application/json
     {
-        "server_keys": { "remote.server.example.com": ["a.key.id"] }
+        "server_keys": {
+            "remote.server.example.com": {
+                "a.key.id": {
+                    "minimum_valid_until_ts": 1234567890123
+                }
+            }
+        }
     }
 
     Response:
@@ -96,10 +103,16 @@ class RemoteKey(Resource):
     def async_render_GET(self, request):
         if len(request.postpath) == 1:
             server, = request.postpath
-            query = {server: [None]}
+            query = {server: {}}
         elif len(request.postpath) == 2:
             server, key_id = request.postpath
-            query = {server: [key_id]}
+            minimum_valid_until_ts = parse_integer(
+                request, "minimum_valid_until_ts"
+            )
+            arguments = {}
+            if minimum_valid_until_ts is not None:
+                arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
+            query = {server: {key_id: arguments}}
         else:
             raise SynapseError(
                 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
@@ -128,8 +141,11 @@ class RemoteKey(Resource):
 
     @defer.inlineCallbacks
     def query_keys(self, request, query, query_remote_on_cache_miss=False):
+        logger.info("Handling query for keys %r", query)
         store_queries = []
         for server_name, key_ids in query.items():
+            if not key_ids:
+                key_ids = (None,)
             for key_id in key_ids:
                 store_queries.append((server_name, key_id, None))
 
@@ -152,9 +168,44 @@ class RemoteKey(Resource):
             if key_id is not None:
                 ts_added_ms, most_recent_result = max(results)
                 ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
-                if (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
+                req_key = query.get(server_name, {}).get(key_id, {})
+                req_valid_until = req_key.get("minimum_valid_until_ts")
+                miss = False
+                if req_valid_until is not None:
+                    if ts_valid_until_ms < req_valid_until:
+                        logger.debug(
+                            "Cached response for %r/%r is older than requested"
+                            ": valid_until (%r) < minimum_valid_until (%r)",
+                            server_name, key_id,
+                            ts_valid_until_ms, req_valid_until
+                        )
+                        miss = True
+                    else:
+                        logger.debug(
+                            "Cached response for %r/%r is newer than requested"
+                            ": valid_until (%r) >= minimum_valid_until (%r)",
+                            server_name, key_id,
+                            ts_valid_until_ms, req_valid_until
+                        )
+                elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
+                    logger.debug(
+                        "Cached response for %r/%r is too old"
+                        ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
+                        server_name, key_id,
+                        ts_added_ms, ts_valid_until_ms, time_now_ms
+                    )
                     # We more than half way through the lifetime of the
                     # response. We should fetch a fresh copy.
+                    miss = True
+                else:
+                    logger.debug(
+                        "Cached response for %r/%r is still valid"
+                        ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
+                        server_name, key_id,
+                        ts_added_ms, ts_valid_until_ms, time_now_ms
+                    )
+
+                if miss:
                     cache_misses.setdefault(server_name, set()).add(key_id)
                 json_results.add(bytes(most_recent_result["key_json"]))
             else: