summary refs log tree commit diff
path: root/synapse/storage/pdu.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/pdu.py')
-rw-r--r--synapse/storage/pdu.py92
1 files changed, 17 insertions, 75 deletions
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 0bf97e37ee..3c859fdeac 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore, Table, JoinHelper
 
+from synapse.federation.units import Pdu
 from synapse.util.logutils import log_function
 
 from collections import namedtuple
@@ -308,8 +309,8 @@ class PduStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_oldest_pdus_in_context(self, context):
-        """Get a list of Pdus that we haven't backfilled beyond yet (and haven't    
-        seen). This list is used when we want to backfill backwards and is the 
+        """Get a list of Pdus that we haven't backfilled beyond yet (and havent
+        seen). This list is used when we want to backfill backwards and is the
         list we send to the remote server.
 
         Args:
@@ -516,7 +517,7 @@ class StatePduStore(SQLBaseStore):
 
         if not current:
             logger.debug("get_unresolved_state_tree No current state.")
-            return return_value
+            return (return_value, None)
 
         return_value.current_branch.append(current)
 
@@ -524,13 +525,16 @@ class StatePduStore(SQLBaseStore):
             txn, new_pdu, current
         )
 
+        missing_branch = None
         for branch, prev_state, state in enum_branches:
             if state:
                 return_value[branch].append(state)
             else:
+                # We don't have prev_state :(
+                missing_branch = branch
                 break
 
-        return return_value
+        return (return_value, missing_branch)
 
     def update_current_state(self, pdu_id, origin, context, pdu_type,
                              state_key):
@@ -622,53 +626,6 @@ class StatePduStore(SQLBaseStore):
 
         return result
 
-    def get_next_missing_pdu(self, new_pdu):
-        """When we get a new state pdu we need to check whether we need to do
-        any conflict resolution, if we do then we need to check if we need
-        to go back and request some more state pdus that we haven't seen yet.
-
-        Args:
-            txn
-            new_pdu
-
-        Returns:
-            PduIdTuple: A pdu that we are missing, or None if we have all the
-                pdus required to do the conflict resolution.
-        """
-        return self._db_pool.runInteraction(
-            self._get_next_missing_pdu, new_pdu
-        )
-
-    def _get_next_missing_pdu(self, txn, new_pdu):
-        logger.debug(
-            "get_next_missing_pdu %s %s",
-            new_pdu.pdu_id, new_pdu.origin
-        )
-
-        current = self._get_current_interaction(
-            txn,
-            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
-        )
-
-        if (not current or not current.prev_state_id
-                or not current.prev_state_origin):
-            return None
-
-        # Oh look, it's a straight clobber, so wooooo almost no-op.
-        if (new_pdu.prev_state_id == current.pdu_id
-                and new_pdu.prev_state_origin == current.origin):
-            return None
-
-        enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
-        for branch, prev_state, state in enum_branches:
-            if not state:
-                return PduIdTuple(
-                    prev_state.prev_state_id,
-                    prev_state.prev_state_origin
-                )
-
-        return None
-
     def handle_new_state(self, new_pdu):
         """Actually perform conflict resolution on the new_pdu on the
         assumption we have all the pdus required to perform it.
@@ -752,24 +709,11 @@ class StatePduStore(SQLBaseStore):
 
         return is_current
 
-    @classmethod
     @log_function
-    def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
+    def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
         branch_a = pdu_a
         branch_b = pdu_b
 
-        get_query = (
-            "SELECT %(fields)s FROM %(pdus)s as p "
-            "LEFT JOIN %(state)s as s "
-            "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
-            "WHERE p.pdu_id = ? AND p.origin = ? "
-        ) % {
-            "fields": _pdu_state_joiner.get_fields(
-                PdusTable="p", StatePdusTable="s"),
-            "pdus": PdusTable.table_name,
-            "state": StatePdusTable.table_name,
-        }
-
         while True:
             if (branch_a.pdu_id == branch_b.pdu_id
                     and branch_a.origin == branch_b.origin):
@@ -801,13 +745,12 @@ class StatePduStore(SQLBaseStore):
                     branch_a.prev_state_origin
                 )
 
-                logger.debug("getting branch_a prev %s", pdu_tuple)
-                txn.execute(get_query, pdu_tuple)
-
                 prev_branch = branch_a
 
-                res = txn.fetchone()
-                branch_a = PduEntry(*res) if res else None
+                logger.debug("getting branch_a prev %s", pdu_tuple)
+                branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
+                if branch_a:
+                    branch_a = Pdu.from_pdu_tuple(branch_a)
 
                 logger.debug("branch_a=%s", branch_a)
 
@@ -820,14 +763,13 @@ class StatePduStore(SQLBaseStore):
                     branch_b.prev_state_id,
                     branch_b.prev_state_origin
                 )
-                txn.execute(get_query, pdu_tuple)
-
-                logger.debug("getting branch_b prev %s", pdu_tuple)
 
                 prev_branch = branch_b
 
-                res = txn.fetchone()
-                branch_b = PduEntry(*res) if res else None
+                logger.debug("getting branch_b prev %s", pdu_tuple)
+                branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
+                if branch_b:
+                    branch_b = Pdu.from_pdu_tuple(branch_b)
 
                 logger.debug("branch_b=%s", branch_b)