diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 719e35b78d..f33c115844 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -76,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().callback(r)
- except Exception:
- pass
+ observer.callback(r)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .callback(%r), ignoring...",
+ observer,
+ r,
+ exc_info=e,
+ )
return r
def errback(f):
@@ -90,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().errback(f)
- except Exception:
- pass
+ observer.errback(f)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .errback(%r), ignoring...",
+ observer,
+ f,
+ exc_info=e,
+ )
if consumeErrors:
return None
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..46ea8e0964 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
- def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
- self.clock = hs.get_clock()
+ self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+ and returns the extracted value.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ MacaroonVerificationFailedException: if there are conflicting values for the
+ caveat in the macaroon, or if the caveat was not found in the macaroon.
+ """
+ prefix = key + " = "
+ result = None # type: Optional[str]
+ for caveat in macaroon.caveats:
+ if not caveat.caveat_id.startswith(prefix):
+ continue
+
+ val = caveat.caveat_id[len(prefix) :]
+
+ if result is None:
+ # first time we found this caveat: record the value
+ result = val
+ elif val != result:
+ # on subsequent occurrences, raise if the value is different.
+ raise MacaroonVerificationFailedException(
+ "Conflicting values for caveat " + key
+ )
+
+ if result is not None:
+ return result
+
+ # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+ # Note that it is insecure to generate a macaroon without all the caveats you
+ # might need (because there is nothing stopping people from adding extra caveats),
+ # so if the caveat isn't there, something odd must be going on.
+ raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+ """Make a macaroon verifier which accepts 'time' caveats
+
+ Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+ the given macaroon verifier.
+
+ Args:
+ v: the macaroon verifier
+ get_time_ms: a callable which will return the timestamp after which the caveat
+ should be considered expired. Normally the current time.
+ """
+
+ def verify_expiry_caveat(caveat: str):
+ time_msec = get_time_ms()
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ return time_msec < expiry
+
+ v.satisfy_general(verify_expiry_caveat)
|