diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index fceca2edeb..00b1b3066e 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,37 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
-from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
-
-def get_transaction_key(request):
- """A helper function which returns a transaction key that can be used
- with TransactionCache for idempotent requests.
-
- Idempotency is based on the returned key being the same for separate
- requests to the same endpoint. The key is formed from the HTTP request
- path and the access_token for the requesting user.
-
- Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
- Returns:
- str: A transaction key
- """
- token = get_access_token_from_request(request)
- return request.path + "/" + token
-
-
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
- def __init__(self, clock):
- self.clock = clock
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = self.hs.get_auth()
+ self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@@ -55,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
+ def _get_transaction_key(self, request):
+ """A helper function which returns a transaction key that can be used
+ with TransactionCache for idempotent requests.
+
+ Idempotency is based on the returned key being the same for separate
+ requests to the same endpoint. The key is formed from the HTTP request
+ path and the access_token for the requesting user.
+
+ Args:
+ request (twisted.web.http.Request): The incoming request. Must
+ contain an access_token.
+ Returns:
+ str: A transaction key
+ """
+ token = self.auth.get_access_token_from_request(request)
+ return request.path + "/" + token
+
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -63,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
- get_transaction_key(request), fn, *args, **kwargs
+ self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
@@ -80,31 +80,30 @@ class HttpTransactionCache(object):
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
- try:
- return self.transactions[txn_key][0].observe()
- except (KeyError, IndexError):
- pass # execute the function instead.
-
- deferred = fn(*args, **kwargs)
-
- # if the request fails with a Twisted failure, remove it
- # from the transaction map. This is done to ensure that we don't
- # cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
- self.transactions.pop(txn_key, None)
- return err
- deferred.addErrback(remove_from_map)
-
- # We don't add any other errbacks to the raw deferred, so we ask
- # ObservableDeferred to swallow the error. This is fine as the error will
- # still be reported to the observers.
- observable = ObservableDeferred(deferred, consumeErrors=True)
- self.transactions[txn_key] = (observable, self.clock.time_msec())
- return observable.observe()
+ if txn_key in self.transactions:
+ observable = self.transactions[txn_key][0]
+ else:
+ # execute the function instead.
+ deferred = run_in_background(fn, *args, **kwargs)
+
+ observable = ObservableDeferred(deferred)
+ self.transactions[txn_key] = (observable, self.clock.time_msec())
+
+ # if the request fails with an exception, remove it
+ # from the transaction map. This is done to ensure that we don't
+ # cache transient errors like rate-limiting errors, etc.
+ def remove_from_map(err):
+ self.transactions.pop(txn_key, None)
+ # we deliberately do not propagate the error any further, as we
+ # expect the observers to have reported it.
+
+ deferred.addErrback(remove_from_map)
+
+ return make_deferred_yieldable(observable.observe())
def _cleanup(self):
now = self.clock.time_msec()
- for key in self.transactions.keys():
+ for key in list(self.transactions):
ts = self.transactions[key][1]
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
del self.transactions[key]
|