diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index db69bb7c06..a3f8d92d08 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging
import random
-from typing import Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -63,15 +66,19 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning:
"""Creates and verifies group attestations."""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.signing_key = hs.signing_key
async def verify_attestation(
- self, attestation, group_id, user_id, server_name=None
- ):
+ self,
+ attestation: JsonDict,
+ group_id: str,
+ user_id: str,
+ server_name: Optional[str] = None,
+ ) -> None:
"""Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's
@@ -100,16 +107,18 @@ class GroupAttestationSigning:
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
+ assert server_name is not None
await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation"
)
- def create_attestation(self, group_id, user_id):
+ def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
"""Create an attestation for the group_id and user_id with default
validity length.
"""
- validity_period = DEFAULT_ATTESTATION_LENGTH_MS
- validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+ validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
+ *DEFAULT_ATTESTATION_JITTER
+ )
valid_until_ms = int(self.clock.time_msec() + validity_period)
return sign_json(
@@ -126,7 +135,7 @@ class GroupAttestationSigning:
class GroupAttestionRenewer:
"""Responsible for sending and receiving attestation updates."""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.assestations = hs.get_groups_attestation_signing()
@@ -139,7 +148,9 @@ class GroupAttestionRenewer:
self._start_renew_attestations, 30 * 60 * 1000
)
- async def on_renew_attestation(self, group_id, user_id, content):
+ async def on_renew_attestation(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""When a remote updates an attestation"""
attestation = content["attestation"]
@@ -154,10 +165,10 @@ class GroupAttestionRenewer:
return {}
- def _start_renew_attestations(self):
+ def _start_renew_attestations(self) -> None:
return run_as_background_process("renew_attestations", self._renew_attestations)
- async def _renew_attestations(self):
+ async def _renew_attestations(self) -> None:
"""Called periodically to check if we need to update any of our attestations"""
now = self.clock.time_msec()
@@ -166,7 +177,7 @@ class GroupAttestionRenewer:
now + UPDATE_ATTESTATION_TIME_MS
)
- async def _renew_attestation(group_user: Tuple[str, str]):
+ async def _renew_attestation(group_user: Tuple[str, str]) -> None:
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
|