diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 0bf97e37ee..d70467dcd6 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore, Table, JoinHelper
+from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
from collections import namedtuple
@@ -42,7 +43,7 @@ class PduStore(SQLBaseStore):
PduTuple: If the pdu does not exist in the database, returns None
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_pdu_tuple, pdu_id, origin
)
@@ -94,7 +95,7 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_current_state_for_context,
context
)
@@ -142,7 +143,7 @@ class PduStore(SQLBaseStore):
pdu_origin (str)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin
)
@@ -151,7 +152,7 @@ class PduStore(SQLBaseStore):
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_all_pdus_from_context, context,
)
@@ -178,7 +179,7 @@ class PduStore(SQLBaseStore):
Return:
list: A list of PduTuples
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_backfill, context, pdu_list, limit
)
@@ -239,7 +240,7 @@ class PduStore(SQLBaseStore):
txn
context (str)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_min_depth_for_context, context
)
@@ -308,8 +309,8 @@ class PduStore(SQLBaseStore):
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
- """Get a list of Pdus that we haven't backfilled beyond yet (and haven't
- seen). This list is used when we want to backfill backwards and is the
+ """Get a list of Pdus that we haven't backfilled beyond yet (and havent
+ seen). This list is used when we want to backfill backwards and is the
list we send to the remote server.
Args:
@@ -345,7 +346,7 @@ class PduStore(SQLBaseStore):
bool
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
@@ -498,7 +499,7 @@ class StatePduStore(SQLBaseStore):
)
def get_unresolved_state_tree(self, new_state_pdu):
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_unresolved_state_tree, new_state_pdu
)
@@ -516,7 +517,7 @@ class StatePduStore(SQLBaseStore):
if not current:
logger.debug("get_unresolved_state_tree No current state.")
- return return_value
+ return (return_value, None)
return_value.current_branch.append(current)
@@ -524,17 +525,20 @@ class StatePduStore(SQLBaseStore):
txn, new_pdu, current
)
+ missing_branch = None
for branch, prev_state, state in enum_branches:
if state:
return_value[branch].append(state)
else:
+ # We don't have prev_state :(
+ missing_branch = branch
break
- return return_value
+ return (return_value, missing_branch)
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
@@ -573,7 +577,7 @@ class StatePduStore(SQLBaseStore):
PduEntry
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_current_state_pdu, context, pdu_type, state_key
)
@@ -622,53 +626,6 @@ class StatePduStore(SQLBaseStore):
return result
- def get_next_missing_pdu(self, new_pdu):
- """When we get a new state pdu we need to check whether we need to do
- any conflict resolution, if we do then we need to check if we need
- to go back and request some more state pdus that we haven't seen yet.
-
- Args:
- txn
- new_pdu
-
- Returns:
- PduIdTuple: A pdu that we are missing, or None if we have all the
- pdus required to do the conflict resolution.
- """
- return self._db_pool.runInteraction(
- self._get_next_missing_pdu, new_pdu
- )
-
- def _get_next_missing_pdu(self, txn, new_pdu):
- logger.debug(
- "get_next_missing_pdu %s %s",
- new_pdu.pdu_id, new_pdu.origin
- )
-
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- if (not current or not current.prev_state_id
- or not current.prev_state_origin):
- return None
-
- # Oh look, it's a straight clobber, so wooooo almost no-op.
- if (new_pdu.prev_state_id == current.pdu_id
- and new_pdu.prev_state_origin == current.origin):
- return None
-
- enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
- for branch, prev_state, state in enum_branches:
- if not state:
- return PduIdTuple(
- prev_state.prev_state_id,
- prev_state.prev_state_origin
- )
-
- return None
-
def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it.
@@ -679,7 +636,7 @@ class StatePduStore(SQLBaseStore):
Returns:
bool: True if the new_pdu clobbered the current state, False if not
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._handle_new_state, new_pdu
)
@@ -752,24 +709,11 @@ class StatePduStore(SQLBaseStore):
return is_current
- @classmethod
@log_function
- def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
+ def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
branch_a = pdu_a
branch_b = pdu_b
- get_query = (
- "SELECT %(fields)s FROM %(pdus)s as p "
- "LEFT JOIN %(state)s as s "
- "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
- "WHERE p.pdu_id = ? AND p.origin = ? "
- ) % {
- "fields": _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s"),
- "pdus": PdusTable.table_name,
- "state": StatePdusTable.table_name,
- }
-
while True:
if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin):
@@ -801,13 +745,12 @@ class StatePduStore(SQLBaseStore):
branch_a.prev_state_origin
)
- logger.debug("getting branch_a prev %s", pdu_tuple)
- txn.execute(get_query, pdu_tuple)
-
prev_branch = branch_a
- res = txn.fetchone()
- branch_a = PduEntry(*res) if res else None
+ logger.debug("getting branch_a prev %s", pdu_tuple)
+ branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
+ if branch_a:
+ branch_a = Pdu.from_pdu_tuple(branch_a)
logger.debug("branch_a=%s", branch_a)
@@ -820,14 +763,13 @@ class StatePduStore(SQLBaseStore):
branch_b.prev_state_id,
branch_b.prev_state_origin
)
- txn.execute(get_query, pdu_tuple)
-
- logger.debug("getting branch_b prev %s", pdu_tuple)
prev_branch = branch_b
- res = txn.fetchone()
- branch_b = PduEntry(*res) if res else None
+ logger.debug("getting branch_b prev %s", pdu_tuple)
+ branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
+ if branch_b:
+ branch_b = Pdu.from_pdu_tuple(branch_b)
logger.debug("branch_b=%s", branch_b)
|