From 7eb6e39a8fe9d42a411cefd905cf2caa29896923 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 4 Mar 2021 14:44:22 +0000 Subject: Record the SSO Auth Provider in the login token (#9510) This great big stack of commits is a a whole load of hoop-jumping to make it easier to store additional values in login tokens, and then to actually store the SSO Identity Provider in the login token. (Making use of that data will follow in a subsequent PR.) --- synapse/util/macaroons.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 synapse/util/macaroons.py (limited to 'synapse/util') diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py new file mode 100644 index 0000000000..12cdd53327 --- /dev/null +++ b/synapse/util/macaroons.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for manipulating macaroons""" + +from typing import Callable, Optional + +import pymacaroons +from pymacaroons.exceptions import MacaroonVerificationFailedException + + +def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Checks that there is exactly one caveat of the form "key = " in the macaroon, + and returns the extracted value. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + MacaroonVerificationFailedException: if there are conflicting values for the + caveat in the macaroon, or if the caveat was not found in the macaroon. + """ + prefix = key + " = " + result = None # type: Optional[str] + for caveat in macaroon.caveats: + if not caveat.caveat_id.startswith(prefix): + continue + + val = caveat.caveat_id[len(prefix) :] + + if result is None: + # first time we found this caveat: record the value + result = val + elif val != result: + # on subsequent occurrences, raise if the value is different. + raise MacaroonVerificationFailedException( + "Conflicting values for caveat " + key + ) + + if result is not None: + return result + + # If the caveat is not there, we raise a MacaroonVerificationFailedException. + # Note that it is insecure to generate a macaroon without all the caveats you + # might need (because there is nothing stopping people from adding extra caveats), + # so if the caveat isn't there, something odd must be going on. + raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,)) + + +def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None: + """Make a macaroon verifier which accepts 'time' caveats + + Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to + the given macaroon verifier. + + Args: + v: the macaroon verifier + get_time_ms: a callable which will return the timestamp after which the caveat + should be considered expired. Normally the current time. + """ + + def verify_expiry_caveat(caveat: str): + time_msec = get_time_ms() + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + return time_msec < expiry + + v.satisfy_general(verify_expiry_caveat) -- cgit 1.5.1 From d6196efafcc312472464c882ab630bc3fbf7bd37 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 8 Mar 2021 20:00:07 +0100 Subject: Add ResponseCache tests. (#9458) --- changelog.d/9458.misc | 1 + synapse/appservice/api.py | 2 +- synapse/federation/federation_server.py | 13 ++-- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_list.py | 4 +- synapse/handlers/sync.py | 2 +- synapse/replication/http/_base.py | 9 ++- synapse/util/caches/response_cache.py | 10 +-- tests/util/caches/test_responsecache.py | 131 ++++++++++++++++++++++++++++++++ 10 files changed, 156 insertions(+), 20 deletions(-) create mode 100644 changelog.d/9458.misc create mode 100644 tests/util/caches/test_responsecache.py (limited to 'synapse/util') diff --git a/changelog.d/9458.misc b/changelog.d/9458.misc new file mode 100644 index 0000000000..8ceeed1352 --- /dev/null +++ b/changelog.d/9458.misc @@ -0,0 +1 @@ +Add tests to ResponseCache. \ No newline at end of file diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 93c2aabcca..9d3bbe3b8b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient): self.clock = hs.get_clock() self.protocol_meta_cache = ResponseCache( - hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS + hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) # type: ResponseCache[Tuple[str, str]] async def query_user(self, service, user_id): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 362895bf42..7657697bfa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,6 +22,7 @@ from typing import ( Awaitable, Callable, Dict, + Iterable, List, Optional, Tuple, @@ -98,7 +99,7 @@ last_pdu_ts_metric = Gauge( class FederationServer(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() @@ -118,7 +119,7 @@ class FederationServer(FederationBase): # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( - hs, "fed_txn_handler", timeout_ms=30000 + hs.get_clock(), "fed_txn_handler", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -128,10 +129,10 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. self._state_resp_cache = ResponseCache( - hs, "state_resp", timeout_ms=30000 + hs.get_clock(), "state_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( - hs, "state_ids_resp", timeout_ms=30000 + hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( @@ -453,7 +454,9 @@ class FederationServer(FederationBase): self, room_id: str, event_id: str ) -> Dict[str, list]: if event_id: - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + pdus = await self.handler.get_state_for_pdu( + room_id, event_id + ) # type: Iterable[EventBase] else: pdus = (await self.state.get_current_state(room_id)).values() diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 71a5076672..13f8152283 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler): self.clock = hs.get_clock() self.validator = EventValidator() self.snapshot_cache = ResponseCache( - hs, "initial_sync_cache" + hs.get_clock(), "initial_sync_cache" ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a488df10d6..4b3d0d72e3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler): # succession, only process the first attempt and return its result to # subsequent requests self._upgrade_response_cache = ResponseCache( - hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS + hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS ) # type: ResponseCache[Tuple[str, str]] self._server_notices_mxid = hs.config.server_notices_mxid diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 14f14db449..8bfc46c654 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler): super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache( - hs, "room_list" + hs.get_clock(), "room_list" ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] self.remote_response_cache = ResponseCache( - hs, "remote_room_list", timeout_ms=30 * 1000 + hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] async def get_local_public_room_list( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e8ed7b33f..f50257cd57 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -244,7 +244,7 @@ class SyncHandler: self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.response_cache = ResponseCache( - hs, "sync" + hs.get_clock(), "sync" ) # type: ResponseCache[Tuple[Any, ...]] self.state = hs.get_state_handler() self.auth = hs.get_auth() diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 8a3f113e76..b7aa0c280f 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -18,7 +18,7 @@ import logging import re import urllib from inspect import signature -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) _pending_outgoing_requests = Gauge( @@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): CACHE = True RETRY_ON_TIMEOUT = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 + hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 32228f42ee..46ea8e0964 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,17 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.util import Clock from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - logger = logging.getLogger(__name__) T = TypeVar("T") @@ -37,11 +35,11 @@ class ResponseCache(Generic[T]): used rather than trying to compute a new response. """ - def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): # Requests that haven't finished yet. self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] - self.clock = hs.get_clock() + self.clock = clock self.timeout_sec = timeout_ms / 1000.0 self._name = name diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_responsecache.py new file mode 100644 index 0000000000..f9a187b8de --- /dev/null +++ b/tests/util/caches/test_responsecache.py @@ -0,0 +1,131 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches.response_cache import ResponseCache + +from tests.server import get_clock +from tests.unittest import TestCase + + +class DeferredCacheTestCase(TestCase): + """ + A TestCase class for ResponseCache. + + The test-case function naming has some logic to it in it's parts, here's some notes about it: + wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available + (Generally these just use .delayed_return instead of .instant_return in it's wrapped call.) + expire: Denotes tests that test expiry after assured existence. + (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) + """ + + def setUp(self): + self.reactor, self.clock = get_clock() + + def with_cache(self, name: str, ms: int = 0) -> ResponseCache: + return ResponseCache(self.clock, name, timeout_ms=ms) + + @staticmethod + async def instant_return(o: str) -> str: + return o + + async def delayed_return(self, o: str) -> str: + await self.clock.sleep(1) + return o + + def test_cache_hit(self): + cache = self.with_cache("keeping_cache", ms=9001) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.instant_return, expected_result) + + self.assertEqual( + expected_result, + self.successResultOf(wrap_d), + "initial wrap result should be the same", + ) + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should have the result", + ) + + def test_cache_miss(self): + cache = self.with_cache("trashing_cache", ms=0) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.instant_return, expected_result) + + self.assertEqual( + expected_result, + self.successResultOf(wrap_d), + "initial wrap result should be the same", + ) + self.assertIsNone(cache.get(0), "cache should not have the result now") + + def test_cache_expire(self): + cache = self.with_cache("short_cache", ms=1000) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.instant_return, expected_result) + + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should still have the result", + ) + + # cache eviction timer is handled + self.reactor.pump((2,)) + + self.assertIsNone(cache.get(0), "cache should not have the result now") + + def test_cache_wait_hit(self): + cache = self.with_cache("neutral_cache") + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.delayed_return, expected_result) + self.assertNoResult(wrap_d) + + # function wakes up, returns result + self.reactor.pump((2,)) + + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + + def test_cache_wait_expire(self): + cache = self.with_cache("medium_cache", ms=3000) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.delayed_return, expected_result) + self.assertNoResult(wrap_d) + + # stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2 + self.reactor.pump((1, 1)) + + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should still have the result", + ) + + # (1 + 1 + 2) > 3.0, cache eviction timer is handled + self.reactor.pump((2,)) + + self.assertIsNone(cache.get(0), "cache should not have the result now") -- cgit 1.5.1 From 9898470e7d7b1b2f3c0dfa7d07832ec4662221da Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 9 Mar 2021 12:09:31 +0100 Subject: Add logging to ObservableDeferred callbacks (#9523) --- changelog.d/9523.misc | 1 + synapse/util/async_helpers.py | 26 ++++++++++++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) create mode 100644 changelog.d/9523.misc (limited to 'synapse/util') diff --git a/changelog.d/9523.misc b/changelog.d/9523.misc new file mode 100644 index 0000000000..f03e939efb --- /dev/null +++ b/changelog.d/9523.misc @@ -0,0 +1 @@ +Add extra logging to ObservableDeferred when callbacks throw exceptions. \ No newline at end of file diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 719e35b78d..f33c115844 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -76,11 +76,16 @@ class ObservableDeferred: def callback(r): object.__setattr__(self, "_result", (True, r)) while self._observers: + observer = self._observers.pop() try: - # TODO: Handle errors here. - self._observers.pop().callback(r) - except Exception: - pass + observer.callback(r) + except Exception as e: + logger.exception( + "%r threw an exception on .callback(%r), ignoring...", + observer, + r, + exc_info=e, + ) return r def errback(f): @@ -90,11 +95,16 @@ class ObservableDeferred: # traces when we `await` on one of the observer deferreds. f.value.__failure__ = f + observer = self._observers.pop() try: - # TODO: Handle errors here. - self._observers.pop().errback(f) - except Exception: - pass + observer.errback(f) + except Exception as e: + logger.exception( + "%r threw an exception on .errback(%r), ignoring...", + observer, + f, + exc_info=e, + ) if consumeErrors: return None -- cgit 1.5.1