diff options
Diffstat (limited to 'synapse/groups/attestations.py')
-rw-r--r-- | synapse/groups/attestations.py | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index db69bb7c06..a3f8d92d08 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like: import logging import random -from typing import Tuple +from typing import TYPE_CHECKING, Optional, Tuple from signedjson.sign import sign_json from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -63,15 +66,19 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 class GroupAttestationSigning: """Creates and verifies group attestations.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.keyring = hs.get_keyring() self.clock = hs.get_clock() self.server_name = hs.hostname self.signing_key = hs.signing_key async def verify_attestation( - self, attestation, group_id, user_id, server_name=None - ): + self, + attestation: JsonDict, + group_id: str, + user_id: str, + server_name: Optional[str] = None, + ) -> None: """Verifies that the given attestation matches the given parameters. An optional server_name can be supplied to explicitly set which server's @@ -100,16 +107,18 @@ class GroupAttestationSigning: if valid_until_ms < now: raise SynapseError(400, "Attestation expired") + assert server_name is not None await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) - def create_attestation(self, group_id, user_id): + def create_attestation(self, group_id: str, user_id: str) -> JsonDict: """Create an attestation for the group_id and user_id with default validity length. """ - validity_period = DEFAULT_ATTESTATION_LENGTH_MS - validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform( + *DEFAULT_ATTESTATION_JITTER + ) valid_until_ms = int(self.clock.time_msec() + validity_period) return sign_json( @@ -126,7 +135,7 @@ class GroupAttestationSigning: class GroupAttestionRenewer: """Responsible for sending and receiving attestation updates.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.assestations = hs.get_groups_attestation_signing() @@ -139,7 +148,9 @@ class GroupAttestionRenewer: self._start_renew_attestations, 30 * 60 * 1000 ) - async def on_renew_attestation(self, group_id, user_id, content): + async def on_renew_attestation( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """When a remote updates an attestation""" attestation = content["attestation"] @@ -154,10 +165,10 @@ class GroupAttestionRenewer: return {} - def _start_renew_attestations(self): + def _start_renew_attestations(self) -> None: return run_as_background_process("renew_attestations", self._renew_attestations) - async def _renew_attestations(self): + async def _renew_attestations(self) -> None: """Called periodically to check if we need to update any of our attestations""" now = self.clock.time_msec() @@ -166,7 +177,7 @@ class GroupAttestionRenewer: now + UPDATE_ATTESTATION_TIME_MS ) - async def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]) -> None: group_id, user_id = group_user try: if not self.is_mine_id(group_id): |