summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-05-04 17:57:17 +0100
committerErik Johnston <erik@matrix.org>2021-05-04 17:57:46 +0100
commitd4175abe52eee9c4738a0fc3a8984a17cb71e5fb (patch)
tree7527250f021b13d647d3bacada7522f7ae490fbb
parentFix remote resource (diff)
downloadsynapse-d4175abe52eee9c4738a0fc3a8984a17cb71e5fb.tar.xz
Allow fetching events
-rw-r--r--synapse/crypto/keyring.py72
-rw-r--r--synapse/federation/federation_base.py17
2 files changed, 59 insertions, 30 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 64aa1dc758..a8c8df2bad 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,7 +16,7 @@
 import abc
 import logging
 import urllib
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from signedjson.key import (
@@ -41,11 +41,9 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config.key import TrustedKeyServer
-from synapse.logging.context import (
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    run_in_background,
-)
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
@@ -72,8 +70,6 @@ class VerifyJsonRequest:
         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
-        request_name: The name of the request.
-
         key_ids: The set of key_ids to that could be used to verify the JSON object
 
         key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
@@ -86,13 +82,32 @@ class VerifyJsonRequest:
     """
 
     server_name = attr.ib(type=str)
-    json_object = attr.ib(type=JsonDict)
+    json_object_callback = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
-    request_name = attr.ib(type=str)
-    key_ids = attr.ib(init=False, type=List[str])
+    key_ids = attr.ib(type=List[str])
 
-    def __attrs_post_init__(self):
-        self.key_ids = signature_ids(self.json_object, self.server_name)
+    @staticmethod
+    def from_json_object(
+        server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
+    ):
+        key_ids = signature_ids(json_object, server_name)
+        return VerifyJsonRequest(
+            server_name, lambda: json_object, minimum_valid_until_ms, key_ids
+        )
+
+    @staticmethod
+    def from_event(
+        server_name: str,
+        minimum_valid_until_ms: int,
+        event: EventBase,
+    ):
+        key_ids = list(event.signatures.get(server_name, []))
+        return VerifyJsonRequest(
+            server_name,
+            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+            minimum_valid_until_ms,
+            key_ids,
+        )
 
 
 class KeyLookupError(ValueError):
@@ -178,8 +193,10 @@ class Keyring:
         validity_time: int,
         request_name: str,
     ) -> defer.Deferred:
-        request = VerifyJsonRequest(
-            server_name, json_object, validity_time, request_name
+        request = VerifyJsonRequest.from_json_object(
+            server_name,
+            validity_time,
+            json_object,
         )
         return defer.ensureDeferred(self._verify_object(request))
 
@@ -189,14 +206,32 @@ class Keyring:
         return [
             defer.ensureDeferred(
                 self._verify_object(
-                    VerifyJsonRequest(
-                        server_name, json_object, validity_time, request_name
+                    VerifyJsonRequest.from_json_object(
+                        server_name,
+                        validity_time,
+                        json_object,
                     )
                 )
             )
             for server_name, json_object, validity_time, request_name in server_and_json
         ]
 
+    def verify_events_for_server(
+        self, server_and_json: Iterable[Tuple[str, EventBase, int]]
+    ) -> List[defer.Deferred]:
+        return [
+            defer.ensureDeferred(
+                self._verify_object(
+                    VerifyJsonRequest.from_event(
+                        server_name,
+                        validity_time,
+                        event,
+                    )
+                )
+            )
+            for server_name, event, validity_time in server_and_json
+        ]
+
     async def _verify_object(self, verify_request: VerifyJsonRequest):
         # TODO: Use a batching thing.
         with (await self._server_queue.queue(verify_request.server_name)):
@@ -239,8 +274,9 @@ class Keyring:
             for key_id in verify_request.key_ids:
                 verify_key = found_keys[key_id].verify_key
                 try:
+                    json_object = verify_request.json_object_callback()
                     verify_signed_json(
-                        verify_request.json_object,
+                        json_object,
                         verify_request.server_name,
                         verify_key,
                     )
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
         return deferreds
 
 
-class PduToCheckSig(
-    namedtuple(
-        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
-    )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
     pass
 
 
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
     pdus_to_check = [
         PduToCheckSig(
             pdu=p,
-            redacted_pdu_json=prune_event(p).get_pdu_json(),
             sender_domain=get_domain_from_id(p.sender),
             deferreds=[],
         )
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
 
-    more_deferreds = keyring.verify_json_objects_for_server(
+    more_deferreds = keyring.verify_events_for_server(
         [
             (
                 p.sender_domain,
-                p.redacted_pdu_json,
+                p.pdu,
                 p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                p.pdu.event_id,
             )
             for p in pdus_to_check_sender
         ]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
             if p.sender_domain != get_domain_from_id(p.pdu.event_id)
         ]
 
-        more_deferreds = keyring.verify_json_objects_for_server(
+        more_deferreds = keyring.verify_events_for_server(
             [
                 (
                     get_domain_from_id(p.pdu.event_id),
-                    p.redacted_pdu_json,
+                    p.pdu,
                     p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                    p.pdu.event_id,
                 )
                 for p in pdus_to_check_event_id
             ]