summary refs log tree commit diff
path: root/synapse/util/retryutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/retryutils.py')
-rw-r--r--synapse/util/retryutils.py69
1 files changed, 42 insertions, 27 deletions
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 129b47cd49..648d9a95a7 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 import logging
 import random
+from types import TracebackType
+from typing import Any, Optional, Type
 
 import synapse.logging.context
 from synapse.api.errors import CodeMessageException
+from synapse.storage import DataStore
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62
 
 
 class NotRetryingDestination(Exception):
-    def __init__(self, retry_last_ts, retry_interval, destination):
+    def __init__(self, retry_last_ts: int, retry_interval: int, destination: str):
         """Raised by the limiter (and federation client) to indicate that we are
         are deliberately not attempting to contact a given server.
 
         Args:
-            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+            retry_last_ts: the unix ts in milliseconds of our last attempt
                 to contact the server.  0 indicates that the last attempt was
                 successful or that we've never actually attempted to connect.
-            retry_interval (int): the time in milliseconds to wait until the next
+            retry_interval: the time in milliseconds to wait until the next
                 attempt.
-            destination (str): the domain in question
+            destination: the domain in question
         """
 
         msg = "Not retrying server %s." % (destination,)
@@ -51,7 +55,13 @@ class NotRetryingDestination(Exception):
         self.destination = destination
 
 
-async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
+async def get_retry_limiter(
+    destination: str,
+    clock: Clock,
+    store: DataStore,
+    ignore_backoff: bool = False,
+    **kwargs: Any,
+) -> "RetryDestinationLimiter":
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
     CodeMessageException with code < 500)
 
     Args:
-        destination (str): name of homeserver
-        clock (synapse.util.clock): timing source
-        store (synapse.storage.transactions.TransactionStore): datastore
-        ignore_backoff (bool): true to ignore the historical backoff data and
+        destination: name of homeserver
+        clock: timing source
+        store: datastore
+        ignore_backoff: true to ignore the historical backoff data and
             try the request anyway. We will still reset the retry_interval on success.
 
     Example usage:
@@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
 class RetryDestinationLimiter:
     def __init__(
         self,
-        destination,
-        clock,
-        store,
-        failure_ts,
-        retry_interval,
-        backoff_on_404=False,
-        backoff_on_failure=True,
+        destination: str,
+        clock: Clock,
+        store: DataStore,
+        failure_ts: Optional[int],
+        retry_interval: int,
+        backoff_on_404: bool = False,
+        backoff_on_failure: bool = True,
     ):
         """Marks the destination as "down" if an exception is thrown in the
         context, except for CodeMessageException with code < 500.
@@ -128,17 +138,17 @@ class RetryDestinationLimiter:
         If no exception is raised, marks the destination as "up".
 
         Args:
-            destination (str)
-            clock (Clock)
-            store (DataStore)
-            failure_ts (int|None): when this destination started failing (in ms since
+            destination
+            clock
+            store
+            failure_ts: when this destination started failing (in ms since
                 the epoch), or zero if the last request was successful
-            retry_interval (int): The next retry interval taken from the
+            retry_interval: The next retry interval taken from the
                 database in milliseconds, or zero if the last request was
                 successful.
-            backoff_on_404 (bool): Back off if we get a 404
+            backoff_on_404: Back off if we get a 404
 
-            backoff_on_failure (bool): set to False if we should not increase the
+            backoff_on_failure: set to False if we should not increase the
                 retry interval on a failure.
         """
         self.clock = clock
@@ -150,10 +160,15 @@ class RetryDestinationLimiter:
         self.backoff_on_404 = backoff_on_404
         self.backoff_on_failure = backoff_on_failure
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         pass
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         valid_err_code = False
         if exc_type is None:
             valid_err_code = True
@@ -161,7 +176,7 @@ class RetryDestinationLimiter:
             # avoid treating exceptions which don't derive from Exception as
             # failures; this is mostly so as not to catch defer._DefGen.
             valid_err_code = True
-        elif issubclass(exc_type, CodeMessageException):
+        elif isinstance(exc_val, CodeMessageException):
             # Some error codes are perfectly fine for some APIs, whereas other
             # APIs may expect to never received e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
@@ -216,7 +231,7 @@ class RetryDestinationLimiter:
             if self.failure_ts is None:
                 self.failure_ts = retry_last_ts
 
-        async def store_retry_timings():
+        async def store_retry_timings() -> None:
             try:
                 await self.store.set_destination_retry_timings(
                     self.destination,