diff --git a/changelog.d/3384.misc b/changelog.d/3384.misc
new file mode 100644
index 0000000000..5d56c876d9
--- /dev/null
+++ b/changelog.d/3384.misc
@@ -0,0 +1 @@
+Rewrite cache list decorator
diff --git a/changelog.d/3628.misc b/changelog.d/3628.misc
new file mode 100644
index 0000000000..1aebefbe18
--- /dev/null
+++ b/changelog.d/3628.misc
@@ -0,0 +1 @@
+Remove unused field "pdu_failures" from transactions.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e501251b6e..657935d1ac 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -207,10 +207,6 @@ class FederationServer(FederationBase):
edu.content
)
- pdu_failures = getattr(transaction, "pdu_failures", [])
- for fail in pdu_failures:
- logger.info("Got failure %r", fail)
-
response = {
"pdus": pdu_results,
}
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5157c3860d..0bb468385d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
self.edus = SortedDict() # stream position -> Edu
- self.failures = SortedDict() # stream position -> (destination, Failure)
-
self.device_messages = SortedDict() # stream position -> destination
self.pos = 1
@@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "failures", "device_messages", "pos_time",
+ "edus", "device_messages", "pos_time",
]:
register(queue_name, getattr(self, queue_name))
@@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]:
del self.edus[key]
- # Delete things out of failure map
- keys = self.failures.keys()
- i = self.failures.bisect_left(position_to_delete)
- for key in keys[:i]:
- del self.failures[key]
-
# Delete things out of device map
keys = self.device_messages.keys()
i = self.device_messages.bisect_left(position_to_delete)
@@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
- def send_failure(self, failure, destination):
- """As per TransactionQueue"""
- pos = self._next_pos()
-
- self.failures[pos] = (destination, str(failure))
- self.notifier.on_new_replication_data()
-
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
@@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
for (pos, edu) in edus:
rows.append((pos, EduRow(edu)))
- # Fetch changed failures
- i = self.failures.bisect_right(from_token)
- j = self.failures.bisect_right(to_token) + 1
- failures = self.failures.items()[i:j]
-
- for (pos, (destination, failure)) in failures:
- rows.append((pos, FailureRow(
- destination=destination,
- failure=failure,
- )))
-
# Fetch changed device messages
i = self.device_messages.bisect_right(from_token)
j = self.device_messages.bisect_right(to_token) + 1
@@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
- "destination", # str
- "failure",
-))):
- """Streams failures to a remote server. Failures are issued when there was
- something wrong with a transaction the remote sent us, e.g. it included
- an event that was invalid.
- """
-
- TypeId = "f"
-
- @staticmethod
- def from_data(data):
- return FailureRow(
- destination=data["destination"],
- failure=data["failure"],
- )
-
- def to_data(self):
- return {
- "destination": self.destination,
- "failure": self.failure,
- }
-
- def add_to_buffer(self, buff):
- buff.failures.setdefault(self.destination, []).append(self.failure)
-
-
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
"destination", # str
))):
@@ -471,7 +417,6 @@ TypeToRow = {
PresenceRow,
KeyedEduRow,
EduRow,
- FailureRow,
DeviceRow,
)
}
@@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
- "failures", # dict of destination -> [failures]
"device_destinations", # set of destinations
))
@@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
presence=[],
keyed_edus={},
edus={},
- failures={},
device_destinations=set(),
)
@@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
edu.destination, edu.edu_type, edu.content, key=None,
)
- for destination, failure_list in iteritems(buff.failures):
- for failure in failure_list:
- transaction_queue.send_failure(destination, failure)
-
for destination in buff.device_destinations:
transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6996d6b695..78f9d40a3a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -116,9 +116,6 @@ class TransactionQueue(object):
),
)
- # destination -> list of tuple(failure, deferred)
- self.pending_failures_by_dest = {}
-
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
@@ -382,19 +379,6 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination)
- def send_failure(self, failure, destination):
- if destination == self.server_name or destination == "localhost":
- return
-
- if not self.can_send_to(destination):
- return
-
- self.pending_failures_by_dest.setdefault(
- destination, []
- ).append(failure)
-
- self._attempt_new_transaction(destination)
-
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -469,7 +453,6 @@ class TransactionQueue(object):
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@@ -497,7 +480,7 @@ class TransactionQueue(object):
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
- if not pending_pdus and not pending_edus and not pending_failures:
+ if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
@@ -507,7 +490,7 @@ class TransactionQueue(object):
# END CRITICAL SECTION
success = yield self._send_new_transaction(
- destination, pending_pdus, pending_edus, pending_failures,
+ destination, pending_pdus, pending_edus,
)
if success:
sent_transactions_counter.inc()
@@ -584,14 +567,12 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
- def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures):
+ def _send_new_transaction(self, destination, pending_pdus, pending_edus):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
- failures = [x.get_dict() for x in pending_failures]
success = True
@@ -601,11 +582,10 @@ class TransactionQueue(object):
logger.debug(
"TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
+ " (pdus: %d, edus: %d)",
destination, txn_id,
len(pdus),
len(edus),
- len(failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
@@ -617,7 +597,6 @@ class TransactionQueue(object):
destination=destination,
pdus=pdus,
edus=edus,
- pdu_failures=failures,
)
self._next_txn_id += 1
@@ -627,12 +606,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
+ " (PDUs: %d, EDUs: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
- len(failures),
)
# Actually send the transaction
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8574898f0c..3b5ea9515a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
)
logger.info(
- "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+ "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
- len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index bb1b3b13f7..c5ab14314e 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "pdu_failures",
]
internal_keys = [
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
- # If we're passed a cache_context then we'll want to call its invalidate()
- # whenever we are invalidated
+ # If we're passed a cache_context then we'll want to call its
+ # invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- # cached is a dict arg -> deferred, where deferred results in a
- # 2-tuple (`arg`, `result`)
results = {}
- cached_defers = {}
- missing = []
+
+ def update_results_dict(res, arg):
+ results[arg] = res
+
+ # list of deferreds to wait for
+ cached_defers = []
+
+ missing = set()
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
- def cache_get(arg):
- return cache.get(arg, callback=invalidate_callback)
+ def arg_to_cache_key(arg):
+ return arg
else:
- key = list(keyargs)
+ keylist = list(keyargs)
- def cache_get(arg):
- key[self.list_pos] = arg
- return cache.get(tuple(key), callback=invalidate_callback)
+ def arg_to_cache_key(arg):
+ keylist[self.list_pos] = arg
+ return tuple(keylist)
for arg in list_args:
try:
- res = cache_get(arg)
-
+ res = cache.get(arg_to_cache_key(arg),
+ callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached_defers[arg] = res
+ res.addCallback(update_results_dict, arg)
+ cached_defers.append(res)
else:
results[arg] = res.get_result()
except KeyError:
- missing.append(arg)
+ missing.add(arg)
if missing:
+ # we need an observable deferred for each entry in the list,
+ # which we put in the cache. Each deferred resolves with the
+ # relevant result for that key.
+ deferreds_map = {}
+ for arg in missing:
+ deferred = defer.Deferred()
+ deferreds_map[arg] = deferred
+ key = arg_to_cache_key(arg)
+ observable = ObservableDeferred(deferred)
+ cache.set(key, observable, callback=invalidate_callback)
+
+ def complete_all(res):
+ # the wrapped function has completed. It returns a
+ # a dict. We can now resolve the observable deferreds in
+ # the cache and update our own result map.
+ for e in missing:
+ val = res.get(e, None)
+ deferreds_map[e].callback(val)
+ results[e] = val
+
+ def errback(f):
+ # the wrapped function has failed. Invalidate any cache
+ # entries we're supposed to be populating, and fail
+ # their deferreds.
+ for e in missing:
+ key = arg_to_cache_key(e)
+ cache.invalidate(key)
+ deferreds_map[e].errback(f)
+
+ # return the failure, to propagate to our caller.
+ return f
+
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = list(missing)
- ret_d = defer.maybeDeferred(
+ cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
**args_to_call
- )
-
- ret_d = ObservableDeferred(ret_d)
-
- # We need to create deferreds for each arg in the list so that
- # we can insert the new deferred into the cache.
- for arg in missing:
- observer = ret_d.observe()
- observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
- observer = ObservableDeferred(observer)
-
- if num_args == 1:
- cache.set(
- arg, observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, arg)
- else:
- key = list(keyargs)
- key[self.list_pos] = arg
- cache.set(
- tuple(key), observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
-
- res = observer.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
-
- cached_defers[arg] = res
+ ).addCallbacks(complete_all, errback))
if cached_defers:
- def update_results_dict(res):
- results.update(res)
- return results
-
- return logcontext.make_deferred_yieldable(defer.gatherResults(
- list(cached_defers.values()),
+ d = defer.gatherResults(
+ cached_defers,
consumeErrors=True,
- ).addCallback(update_results_dict).addErrback(
+ ).addCallbacks(
+ lambda _: results,
unwrapFirstError
- ))
+ )
+ return logcontext.make_deferred_yieldable(d)
else:
return results
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache.
Args:
- cache (Cache): The underlying cache to use.
+ cached_method_name (str): The name of the single-item lookup method.
+ This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b08856f763..2c263af1a3 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content,
}
],
- "pdu_failures": [],
}
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 8176a7dabd..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def test_cache(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ assert (
+ logcontext.LoggingContext.current_context().request == "c1"
+ )
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ assert (
+ logcontext.LoggingContext.current_context().request == "c1"
+ )
+ defer.returnValue(self.mock(args1, arg2))
+
+ with logcontext.LoggingContext() as c1:
+ c1.request = "c1"
+ obj = Cls()
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ d1 = obj.list_fn([10, 20], 2)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ r = yield d1
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ c1
+ )
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = {30: 'peas'}
+ r = yield obj.list_fn([20, 30], 2)
+ obj.mock.assert_called_once_with([30], 2)
+ self.assertEqual(r, {20: 'chips', 30: 'peas'})
+ obj.mock.reset_mock()
+
+ # all the values should now be cached
+ r = yield obj.fn(10, 2)
+ self.assertEqual(r, 'fish')
+ r = yield obj.fn(20, 2)
+ self.assertEqual(r, 'chips')
+ r = yield obj.fn(30, 2)
+ self.assertEqual(r, 'peas')
+ r = yield obj.list_fn([10, 20, 30], 2)
+ obj.mock.assert_not_called()
+ self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ """Make sure that invalidation callbacks are called."""
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ defer.returnValue(self.mock(args1, arg2))
+
+ obj = Cls()
+ invalidate0 = mock.Mock()
+ invalidate1 = mock.Mock()
+
+ # cache miss
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # cache hit
+ r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+ obj.mock.assert_not_called()
+ self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+ invalidate0.assert_not_called()
+ invalidate1.assert_not_called()
+
+ # now if we invalidate the keys, both invalidations should get called
+ obj.fn.invalidate((10, 2))
+ invalidate0.assert_called_once()
+ invalidate1.assert_called_once()
|