diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 8d69e12d36..351170edbc 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -41,12 +41,19 @@ def get_transaction_key(request):
return request.path + "/" + token
+CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
+
+
class HttpTransactionCache(object):
- def __init__(self):
+ def __init__(self, clock):
+ self.clock = clock
self.transactions = {
- # $txn_key: ObservableDeferred<(res_code, res_json_body)>
+ # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
+ # Try to clean entries every 30 mins. This means entries will exist
+ # for at *LEAST* 30 mins, and at *MOST* 60 mins.
+ self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
@@ -74,11 +81,18 @@ class HttpTransactionCache(object):
Deferred which resolves to a tuple of (response_code, response_dict).
"""
try:
- return self.transactions[txn_key].observe()
- except KeyError:
+ return self.transactions[txn_key][0].observe()
+ except (KeyError, IndexError):
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
- self.transactions[txn_key] = observable
+ self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()
+
+ def _cleanup(self):
+ now = self.clock.time_msec()
+ for key in self.transactions.keys():
+ ts = self.transactions[key][1]
+ if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
+ del self.transactions[key]
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 07ff5b218c..c7aa0bbf59 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -60,4 +60,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth()
- self.txns = HttpTransactionCache()
+ self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 2187350d42..ac660669f3 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache()
+ self.txns = HttpTransactionCache(hs.get_clock())
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 2b3f0bef3c..c05b9450be 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -34,7 +34,7 @@ class Clock(object):
"""A small utility that obtains current time-of-day so that time may be
mocked during unit-tests.
- TODO(paul): Also move the sleep() functionallity into it
+ TODO(paul): Also move the sleep() functionality into it
"""
def time(self):
@@ -46,6 +46,14 @@ class Clock(object):
return int(self.time() * 1000)
def looping_call(self, f, msec):
+ """Call a function repeatedly.
+
+ Waits `msec` initially before calling `f` for the first time.
+
+ Args:
+ f(function): The function to call repeatedly.
+ msec(float): How long to wait between calls in milliseconds.
+ """
l = task.LoopingCall(f)
l.start(msec / 1000.0, now=False)
return l
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
new file mode 100644
index 0000000000..d7cea30260
--- /dev/null
+++ b/tests/rest/client/test_transactions.py
@@ -0,0 +1,69 @@
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
+from twisted.internet import defer
+from mock import Mock, call
+from tests import unittest
+from tests.utils import MockClock
+
+
+class HttpTransactionCacheTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.clock = MockClock()
+ self.cache = HttpTransactionCache(self.clock)
+
+ self.mock_http_response = (200, "GOOD JOB!")
+ self.mock_key = "foo"
+
+ @defer.inlineCallbacks
+ def test_executes_given_function(self):
+ cb = Mock(
+ return_value=defer.succeed(self.mock_http_response)
+ )
+ res = yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "some_arg", keyword="arg"
+ )
+ cb.assert_called_once_with("some_arg", keyword="arg")
+ self.assertEqual(res, self.mock_http_response)
+
+ @defer.inlineCallbacks
+ def test_deduplicates_based_on_key(self):
+ cb = Mock(
+ return_value=defer.succeed(self.mock_http_response)
+ )
+ for i in range(3): # invoke multiple times
+ res = yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+ )
+ self.assertEqual(res, self.mock_http_response)
+ # expect only a single call to do the work
+ cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
+
+ @defer.inlineCallbacks
+ def test_cleans_up(self):
+ cb = Mock(
+ return_value=defer.succeed(self.mock_http_response)
+ )
+ yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "an arg"
+ )
+ # should NOT have cleaned up yet
+ self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
+
+ yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "an arg"
+ )
+ # still using cache
+ cb.assert_called_once_with("an arg")
+
+ self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
+
+ yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "an arg"
+ )
+ # no longer using cache
+ self.assertEqual(cb.call_count, 2)
+ self.assertEqual(
+ cb.call_args_list,
+ [call("an arg",), call("an arg",)]
+ )
|