summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/client/transactions.py22
-rw-r--r--tests/rest/client/test_transactions.py54
2 files changed, 67 insertions, 9 deletions
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index fceca2edeb..93ce0f5348 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -87,19 +87,23 @@ class HttpTransactionCache(object):
 
         deferred = fn(*args, **kwargs)
 
-        # if the request fails with a Twisted failure, remove it
-        # from the transaction map. This is done to ensure that we don't
-        # cache transient errors like rate-limiting errors, etc.
+        observable = ObservableDeferred(deferred, consumeErrors=False)
+        self.transactions[txn_key] = (observable, self.clock.time_msec())
+
+        # if the request fails with an exception, remove it from the
+        # transaction map. This is done to ensure that we don't cache
+        # transient errors like rate-limiting errors, etc.
+        #
+        # (make sure we add this errback *after* adding the key above, in case
+        # the deferred has already failed and is running errbacks
+        # synchronously)
         def remove_from_map(err):
             self.transactions.pop(txn_key, None)
-            return err
+            # we deliberately do not propagate the error any further, as we
+            # expect the observers to have reported it.
+
         deferred.addErrback(remove_from_map)
 
-        # We don't add any other errbacks to the raw deferred, so we ask
-        # ObservableDeferred to swallow the error. This is fine as the error will
-        # still be reported to the observers.
-        observable = ObservableDeferred(deferred, consumeErrors=True)
-        self.transactions[txn_key] = (observable, self.clock.time_msec())
         return observable.observe()
 
     def _cleanup(self):
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index d7cea30260..b650a7772b 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -2,6 +2,8 @@ from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
 from twisted.internet import defer
 from mock import Mock, call
+
+from synapse.util.logcontext import LoggingContext
 from tests import unittest
 from tests.utils import MockClock
 
@@ -40,6 +42,58 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
 
     @defer.inlineCallbacks
+    def test_does_not_cache_exceptions(self):
+        """Checks that, if the callback throws an exception, it is called again
+        for the next request.
+        """
+        called = [False]
+
+        def cb():
+            if called[0]:
+                # return a valid result the second time
+                return defer.succeed(self.mock_http_response)
+
+            called[0] = True
+            raise Exception("boo")
+
+        with LoggingContext("test") as test_context:
+            try:
+                yield self.cache.fetch_or_execute(self.mock_key, cb)
+            except Exception as e:
+                self.assertEqual(e.message, "boo")
+            self.assertIs(LoggingContext.current_context(), test_context)
+
+            res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+            self.assertEqual(res, self.mock_http_response)
+            self.assertIs(LoggingContext.current_context(), test_context)
+
+    @defer.inlineCallbacks
+    def test_does_not_cache_failures(self):
+        """Checks that, if the callback returns a failure, it is called again
+        for the next request.
+        """
+        called = [False]
+
+        def cb():
+            if called[0]:
+                # return a valid result the second time
+                return defer.succeed(self.mock_http_response)
+
+            called[0] = True
+            return defer.fail(Exception("boo"))
+
+        with LoggingContext("test") as test_context:
+            try:
+                yield self.cache.fetch_or_execute(self.mock_key, cb)
+            except Exception as e:
+                self.assertEqual(e.message, "boo")
+            self.assertIs(LoggingContext.current_context(), test_context)
+
+            res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+            self.assertEqual(res, self.mock_http_response)
+            self.assertIs(LoggingContext.current_context(), test_context)
+
+    @defer.inlineCallbacks
     def test_cleans_up(self):
         cb = Mock(
             return_value=defer.succeed(self.mock_http_response)