summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/federation/send_queue.py63
-rw-r--r--synapse/federation/transaction_queue.py32
-rw-r--r--synapse/federation/transport/server.py3
-rw-r--r--synapse/federation/units.py1
-rw-r--r--synapse/util/caches/descriptors.py131
6 files changed, 71 insertions, 163 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e501251b6e..657935d1ac 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -207,10 +207,6 @@ class FederationServer(FederationBase):
                     edu.content
                 )
 
-        pdu_failures = getattr(transaction, "pdu_failures", [])
-        for fail in pdu_failures:
-            logger.info("Got failure %r", fail)
-
         response = {
             "pdus": pdu_results,
         }
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5157c3860d..0bb468385d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
 
         self.edus = SortedDict()  # stream position -> Edu
 
-        self.failures = SortedDict()  # stream position -> (destination, Failure)
-
         self.device_messages = SortedDict()  # stream position -> destination
 
         self.pos = 1
@@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
 
         for queue_name in [
             "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
-            "edus", "failures", "device_messages", "pos_time",
+            "edus", "device_messages", "pos_time",
         ]:
             register(queue_name, getattr(self, queue_name))
 
@@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
             for key in keys[:i]:
                 del self.edus[key]
 
-            # Delete things out of failure map
-            keys = self.failures.keys()
-            i = self.failures.bisect_left(position_to_delete)
-            for key in keys[:i]:
-                del self.failures[key]
-
             # Delete things out of device map
             keys = self.device_messages.keys()
             i = self.device_messages.bisect_left(position_to_delete)
@@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
 
         self.notifier.on_new_replication_data()
 
-    def send_failure(self, failure, destination):
-        """As per TransactionQueue"""
-        pos = self._next_pos()
-
-        self.failures[pos] = (destination, str(failure))
-        self.notifier.on_new_replication_data()
-
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
@@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
         for (pos, edu) in edus:
             rows.append((pos, EduRow(edu)))
 
-        # Fetch changed failures
-        i = self.failures.bisect_right(from_token)
-        j = self.failures.bisect_right(to_token) + 1
-        failures = self.failures.items()[i:j]
-
-        for (pos, (destination, failure)) in failures:
-            rows.append((pos, FailureRow(
-                destination=destination,
-                failure=failure,
-            )))
-
         # Fetch changed device messages
         i = self.device_messages.bisect_right(from_token)
         j = self.device_messages.bisect_right(to_token) + 1
@@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
-    "destination",  # str
-    "failure",
-))):
-    """Streams failures to a remote server. Failures are issued when there was
-    something wrong with a transaction the remote sent us, e.g. it included
-    an event that was invalid.
-    """
-
-    TypeId = "f"
-
-    @staticmethod
-    def from_data(data):
-        return FailureRow(
-            destination=data["destination"],
-            failure=data["failure"],
-        )
-
-    def to_data(self):
-        return {
-            "destination": self.destination,
-            "failure": self.failure,
-        }
-
-    def add_to_buffer(self, buff):
-        buff.failures.setdefault(self.destination, []).append(self.failure)
-
-
 class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
     "destination",  # str
 ))):
@@ -471,7 +417,6 @@ TypeToRow = {
         PresenceRow,
         KeyedEduRow,
         EduRow,
-        FailureRow,
         DeviceRow,
     )
 }
@@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
     "presence",  # list(UserPresenceState)
     "keyed_edus",  # dict of destination -> { key -> Edu }
     "edus",  # dict of destination -> [Edu]
-    "failures",  # dict of destination -> [failures]
     "device_destinations",  # set of destinations
 ))
 
@@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
         presence=[],
         keyed_edus={},
         edus={},
-        failures={},
         device_destinations=set(),
     )
 
@@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
                 edu.destination, edu.edu_type, edu.content, key=None,
             )
 
-    for destination, failure_list in iteritems(buff.failures):
-        for failure in failure_list:
-            transaction_queue.send_failure(destination, failure)
-
     for destination in buff.device_destinations:
         transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6996d6b695..78f9d40a3a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -116,9 +116,6 @@ class TransactionQueue(object):
             ),
         )
 
-        # destination -> list of tuple(failure, deferred)
-        self.pending_failures_by_dest = {}
-
         # destination -> stream_id of last successfully sent to-device message.
         # NB: may be a long or an int.
         self.last_device_stream_id_by_dest = {}
@@ -382,19 +379,6 @@ class TransactionQueue(object):
 
         self._attempt_new_transaction(destination)
 
-    def send_failure(self, failure, destination):
-        if destination == self.server_name or destination == "localhost":
-            return
-
-        if not self.can_send_to(destination):
-            return
-
-        self.pending_failures_by_dest.setdefault(
-            destination, []
-        ).append(failure)
-
-        self._attempt_new_transaction(destination)
-
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
             return
@@ -469,7 +453,6 @@ class TransactionQueue(object):
                 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                 pending_edus = self.pending_edus_by_dest.pop(destination, [])
                 pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
                 pending_edus.extend(
                     self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@@ -497,7 +480,7 @@ class TransactionQueue(object):
                     logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
                                  destination, len(pending_pdus))
 
-                if not pending_pdus and not pending_edus and not pending_failures:
+                if not pending_pdus and not pending_edus:
                     logger.debug("TX [%s] Nothing to send", destination)
                     self.last_device_stream_id_by_dest[destination] = (
                         device_stream_id
@@ -507,7 +490,7 @@ class TransactionQueue(object):
                 # END CRITICAL SECTION
 
                 success = yield self._send_new_transaction(
-                    destination, pending_pdus, pending_edus, pending_failures,
+                    destination, pending_pdus, pending_edus,
                 )
                 if success:
                     sent_transactions_counter.inc()
@@ -584,14 +567,12 @@ class TransactionQueue(object):
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
-    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+    def _send_new_transaction(self, destination, pending_pdus, pending_edus):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
         pdus = [x[0] for x in pending_pdus]
         edus = pending_edus
-        failures = [x.get_dict() for x in pending_failures]
 
         success = True
 
@@ -601,11 +582,10 @@ class TransactionQueue(object):
 
         logger.debug(
             "TX [%s] {%s} Attempting new transaction"
-            " (pdus: %d, edus: %d, failures: %d)",
+            " (pdus: %d, edus: %d)",
             destination, txn_id,
             len(pdus),
             len(edus),
-            len(failures)
         )
 
         logger.debug("TX [%s] Persisting transaction...", destination)
@@ -617,7 +597,6 @@ class TransactionQueue(object):
             destination=destination,
             pdus=pdus,
             edus=edus,
-            pdu_failures=failures,
         )
 
         self._next_txn_id += 1
@@ -627,12 +606,11 @@ class TransactionQueue(object):
         logger.debug("TX [%s] Persisted transaction", destination)
         logger.info(
             "TX [%s] {%s} Sending transaction [%s],"
-            " (PDUs: %d, EDUs: %d, failures: %d)",
+            " (PDUs: %d, EDUs: %d)",
             destination, txn_id,
             transaction.transaction_id,
             len(pdus),
             len(edus),
-            len(failures),
         )
 
         # Actually send the transaction
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8574898f0c..3b5ea9515a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
             )
 
             logger.info(
-                "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
                 transaction_id, origin,
                 len(transaction_data.get("pdus", [])),
                 len(transaction_data.get("edus", [])),
-                len(transaction_data.get("failures", [])),
             )
 
             # We should ideally be getting this from the security layer.
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index bb1b3b13f7..c5ab14314e 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
         "previous_ids",
         "pdus",
         "edus",
-        "pdu_failures",
     ]
 
     internal_keys = [
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
 
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
 
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
 
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
-
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
 
             if missing:
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
+
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a
+                    # a dict. We can now resolve the observable deferreds in
+                    # the cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = list(missing)
 
-                ret_d = defer.maybeDeferred(
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
 
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
 
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache