diff options
Diffstat (limited to 'synapse/storage/pdu.py')
-rw-r--r-- | synapse/storage/pdu.py | 915 |
1 files changed, 0 insertions, 915 deletions
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py deleted file mode 100644 index d70467dcd6..0000000000 --- a/synapse/storage/pdu.py +++ /dev/null @@ -1,915 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from ._base import SQLBaseStore, Table, JoinHelper - -from synapse.federation.units import Pdu -from synapse.util.logutils import log_function - -from collections import namedtuple - -import logging - -logger = logging.getLogger(__name__) - - -class PduStore(SQLBaseStore): - """A collection of queries for handling PDUs. - """ - - def get_pdu(self, pdu_id, origin): - """Given a pdu_id and origin, get a PDU. - - Args: - txn - pdu_id (str) - origin (str) - - Returns: - PduTuple: If the pdu does not exist in the database, returns None - """ - - return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin - ) - - def _get_pdu_tuple(self, txn, pdu_id, origin): - res = self._get_pdu_tuples(txn, [(pdu_id, origin)]) - return res[0] if res else None - - def _get_pdu_tuples(self, txn, pdu_id_tuples): - results = [] - for pdu_id, origin in pdu_id_tuples: - txn.execute( - PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"), - (pdu_id, origin) - ) - - edges = [ - (r.prev_pdu_id, r.prev_origin) - for r in PduEdgesTable.decode_results(txn.fetchall()) - ] - - query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - row = txn.fetchone() - if row: - results.append(PduTuple(PduEntry(*row), edges)) - - return results - - def get_current_state_for_context(self, context): - """Get a list of PDUs that represent the current state for a given - context - - Args: - context (str) - - Returns: - list: A list of PduTuples - """ - - return self.runInteraction( - self._get_current_state_for_context, - context - ) - - def _get_current_state_for_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s WHERE context = ?" - % CurrentStateTable.table_name - ) - - logger.debug("get_current_state %s, Args=%s", query, context) - txn.execute(query, (context,)) - - res = txn.fetchall() - - logger.debug("get_current_state %d results", len(res)) - - return self._get_pdu_tuples(txn, res) - - def _persist_pdu_txn(self, txn, prev_pdus, cols): - """Inserts a (non-state) PDU into the database. - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable. - """ - entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - - txn.execute(PdusTable.insert_statement(), entry) - - self._handle_prev_pdus( - txn, entry.outlier, entry.pdu_id, entry.origin, - prev_pdus, entry.context - ) - - def mark_pdu_as_processed(self, pdu_id, pdu_origin): - """Mark a received PDU as processed. - - Args: - txn - pdu_id (str) - pdu_origin (str) - """ - - return self.runInteraction( - self._mark_as_processed, pdu_id, pdu_origin - ) - - def _mark_as_processed(self, txn, pdu_id, pdu_origin): - txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name) - - def get_all_pdus_from_context(self, context): - """Get a list of all PDUs for a given context.""" - return self.runInteraction( - self._get_all_pdus_from_context, context, - ) - - def _get_all_pdus_from_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s " - "WHERE context = ?" - ) % PdusTable.table_name - - txn.execute(query, (context,)) - - return self._get_pdu_tuples(txn, txn.fetchall()) - - def get_backfill(self, context, pdu_list, limit): - """Get a list of Pdus for a given topic that occured before (and - including) the pdus in pdu_list. Return a list of max size `limit`. - - Args: - txn - context (str) - pdu_list (list) - limit (int) - - Return: - list: A list of PduTuples - """ - return self.runInteraction( - self._get_backfill, context, pdu_list, limit - ) - - def _get_backfill(self, txn, context, pdu_list, limit): - logger.debug( - "backfill: %s, %s, %s", - context, repr(pdu_list), limit - ) - - # We seed the pdu_results with the things from the pdu_list. - pdu_results = pdu_list - - front = pdu_list - - query = ( - "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s " - "WHERE context = ? AND pdu_id = ? AND origin = ? " - "LIMIT ?" - ) % { - "edges_table": PduEdgesTable.table_name, - } - - # We iterate through all pdu_ids in `front` to select their previous - # pdus. These are dumped in `new_front`. We continue until we reach the - # limit *or* new_front is empty (i.e., we've run out of things to - # select - while front and len(pdu_results) < limit: - - new_front = [] - for pdu_id, origin in front: - logger.debug( - "_backfill_interaction: i=%s, o=%s", - pdu_id, origin - ) - - txn.execute( - query, - (context, pdu_id, origin, limit - len(pdu_results)) - ) - - for row in txn.fetchall(): - logger.debug( - "_backfill_interaction: got i=%s, o=%s", - *row - ) - new_front.append(row) - - front = new_front - pdu_results += new_front - - # We also want to update the `prev_pdus` attributes before returning. - return self._get_pdu_tuples(txn, pdu_results) - - def get_min_depth_for_context(self, context): - """Get the current minimum depth for a context - - Args: - txn - context (str) - """ - return self.runInteraction( - self._get_min_depth_for_context, context - ) - - def _get_min_depth_for_context(self, txn, context): - return self._get_min_depth_interaction(txn, context) - - def _get_min_depth_interaction(self, txn, context): - txn.execute( - "SELECT min_depth FROM %s WHERE context = ?" - % ContextDepthTable.table_name, - (context,) - ) - - row = txn.fetchone() - - return row[0] if row else None - - def _update_min_depth_for_context_txn(self, txn, context, depth): - """Update the minimum `depth` of the given context, which is the line - on which we stop backfilling backwards. - - Args: - context (str) - depth (int) - """ - min_depth = self._get_min_depth_interaction(txn, context) - - do_insert = depth < min_depth if min_depth else True - - if do_insert: - txn.execute( - "INSERT OR REPLACE INTO %s (context, min_depth) " - "VALUES (?,?)" % ContextDepthTable.table_name, - (context, depth) - ) - - def _get_latest_pdus_in_context(self, txn, context): - """Get's a list of the most current pdus for a given context. This is - used when we are sending a Pdu and need to fill out the `prev_pdus` - key - - Args: - txn - context - """ - query = ( - "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p " - "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id " - "AND f.origin = p.origin " - "WHERE f.context = ?" - ) % { - "pdus": PdusTable.table_name, - "forward": PduForwardExtremitiesTable.table_name, - } - - logger.debug("get_prev query: %s", query) - - txn.execute( - query, - (context, ) - ) - - results = txn.fetchall() - - return [(row[0], row[1], row[2]) for row in results] - - @defer.inlineCallbacks - def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and havent - seen). This list is used when we want to backfill backwards and is the - list we send to the remote server. - - Args: - txn - context (str) - - Returns: - list: A list of PduIdTuple. - """ - results = yield self._execute( - None, - "SELECT pdu_id, origin FROM %(back)s WHERE context = ?" - % {"back": PduBackwardExtremitiesTable.table_name, }, - context - ) - - defer.returnValue([PduIdTuple(i, o) for i, o in results]) - - def is_pdu_new(self, pdu_id, origin, context, depth): - """For a given Pdu, try and figure out if it's 'new', i.e., if it's - not something we got randomly from the past, for example when we - request the current state of the room that will probably return a bunch - of pdus from before we joined. - - Args: - txn - pdu_id (str) - origin (str) - context (str) - depth (int) - - Returns: - bool - """ - - return self.runInteraction( - self._is_pdu_new, - pdu_id=pdu_id, - origin=origin, - context=context, - depth=depth - ) - - def _is_pdu_new(self, txn, pdu_id, origin, context, depth): - # If depth > min depth in back table, then we classify it as new. - # OR if there is nothing in the back table, then it kinda needs to - # be a new thing. - query = ( - "SELECT min(p.depth) FROM %(edges)s as e " - "INNER JOIN %(back)s as b " - "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin " - "INNER JOIN %(pdus)s as p " - "ON e.pdu_id = p.pdu_id AND p.origin = e.origin " - "WHERE p.context = ?" - ) % { - "pdus": PdusTable.table_name, - "edges": PduEdgesTable.table_name, - "back": PduBackwardExtremitiesTable.table_name, - } - - txn.execute(query, (context,)) - - min_depth, = txn.fetchone() - - if not min_depth or depth > int(min_depth): - logger.debug( - "is_new true: id=%s, o=%s, d=%s min_depth=%s", - pdu_id, origin, depth, min_depth - ) - return True - - # If this pdu is in the forwards table, then it also is a new one - query = ( - "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?" - ) % { - "forward": PduForwardExtremitiesTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - # Did we get anything? - if txn.fetchall(): - logger.debug( - "is_new true: id=%s, o=%s, d=%s was forward", - pdu_id, origin, depth - ) - return True - - logger.debug( - "is_new false: id=%s, o=%s, d=%s", - pdu_id, origin, depth - ) - - # FINE THEN. It's probably old. - return False - - @staticmethod - @log_function - def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus, - context): - txn.executemany( - PduEdgesTable.insert_statement(), - [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus] - ) - - # Update the extremities table if this is not an outlier. - if not outlier: - - # First, we delete the new one from the forwards extremities table. - query = ( - "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" - % PduForwardExtremitiesTable.table_name - ) - txn.executemany(query, prev_pdus) - - # We only insert as a forward extremety the new pdu if there are no - # other pdus that reference it as a prev pdu - query = ( - "INSERT INTO %(table)s (pdu_id, origin, context) " - "SELECT ?, ?, ? WHERE NOT EXISTS (" - "SELECT 1 FROM %(pdu_edges)s WHERE " - "prev_pdu_id = ? AND prev_origin = ?" - ")" - ) % { - "table": PduForwardExtremitiesTable.table_name, - "pdu_edges": PduEdgesTable.table_name - } - - logger.debug("query: %s", query) - - txn.execute(query, (pdu_id, origin, context, pdu_id, origin)) - - # Insert all the prev_pdus as a backwards thing, they'll get - # deleted in a second if they're incorrect anyway. - txn.executemany( - PduBackwardExtremitiesTable.insert_statement(), - [(i, o, context) for i, o in prev_pdus] - ) - - # Also delete from the backwards extremities table all ones that - # reference pdus that we have already seen - query = ( - "DELETE FROM %(pdu_back)s WHERE EXISTS (" - "SELECT 1 FROM %(pdus)s AS pdus " - "WHERE " - "%(pdu_back)s.pdu_id = pdus.pdu_id " - "AND %(pdu_back)s.origin = pdus.origin " - "AND not pdus.outlier " - ")" - ) % { - "pdu_back": PduBackwardExtremitiesTable.table_name, - "pdus": PdusTable.table_name, - } - txn.execute(query) - - -class StatePduStore(SQLBaseStore): - """A collection of queries for handling state PDUs. - """ - - def _persist_state_txn(self, txn, prev_pdus, cols): - """Inserts a state PDU into the database - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable and StatePdusTable - """ - pdu_entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - state_entry = StatePdusTable.EntryType( - **{k: cols.get(k, None) for k in StatePdusTable.fields} - ) - - logger.debug("Inserting pdu: %s", repr(pdu_entry)) - logger.debug("Inserting state: %s", repr(state_entry)) - - txn.execute(PdusTable.insert_statement(), pdu_entry) - txn.execute(StatePdusTable.insert_statement(), state_entry) - - self._handle_prev_pdus( - txn, - pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus, - pdu_entry.context - ) - - def get_unresolved_state_tree(self, new_state_pdu): - return self.runInteraction( - self._get_unresolved_state_tree, new_state_pdu - ) - - @log_function - def _get_unresolved_state_tree(self, txn, new_pdu): - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - ReturnType = namedtuple( - "StateReturnType", ["new_branch", "current_branch"] - ) - return_value = ReturnType([new_pdu], []) - - if not current: - logger.debug("get_unresolved_state_tree No current state.") - return (return_value, None) - - return_value.current_branch.append(current) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - - missing_branch = None - for branch, prev_state, state in enum_branches: - if state: - return_value[branch].append(state) - else: - # We don't have prev_state :( - missing_branch = branch - break - - return (return_value, missing_branch) - - def update_current_state(self, pdu_id, origin, context, pdu_type, - state_key): - return self.runInteraction( - self._update_current_state, - pdu_id, origin, context, pdu_type, state_key - ) - - def _update_current_state(self, txn, pdu_id, origin, context, pdu_type, - state_key): - query = ( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - ) % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - } - - query_args = CurrentStateTable.EntryType( - pdu_id=pdu_id, - origin=origin, - context=context, - pdu_type=pdu_type, - state_key=state_key - ) - - txn.execute(query, query_args) - - def get_current_state_pdu(self, context, pdu_type, state_key): - """For a given context, pdu_type, state_key 3-tuple, return what is - currently considered the current state. - - Args: - txn - context (str) - pdu_type (str) - state_key (str) - - Returns: - PduEntry - """ - - return self.runInteraction( - self._get_current_state_pdu, context, pdu_type, state_key - ) - - def _get_current_state_pdu(self, txn, context, pdu_type, state_key): - return self._get_current_interaction(txn, context, pdu_type, state_key) - - def _get_current_interaction(self, txn, context, pdu_type, state_key): - logger.debug( - "_get_current_interaction %s %s %s", - context, pdu_type, state_key - ) - - fields = _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s") - - current_query = ( - "SELECT %(fields)s FROM %(state)s as s " - "INNER JOIN %(pdus)s as p " - "ON s.pdu_id = p.pdu_id AND s.origin = p.origin " - "INNER JOIN %(curr)s as c " - "ON s.pdu_id = c.pdu_id AND s.origin = c.origin " - "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? " - ) % { - "fields": fields, - "curr": CurrentStateTable.table_name, - "state": StatePdusTable.table_name, - "pdus": PdusTable.table_name, - } - - txn.execute( - current_query, - (context, pdu_type, state_key) - ) - - row = txn.fetchone() - - result = PduEntry(*row) if row else None - - if not result: - logger.debug("_get_current_interaction not found") - else: - logger.debug( - "_get_current_interaction found %s %s", - result.pdu_id, result.origin - ) - - return result - - def handle_new_state(self, new_pdu): - """Actually perform conflict resolution on the new_pdu on the - assumption we have all the pdus required to perform it. - - Args: - new_pdu - - Returns: - bool: True if the new_pdu clobbered the current state, False if not - """ - return self.runInteraction( - self._handle_new_state, new_pdu - ) - - def _handle_new_state(self, txn, new_pdu): - logger.debug( - "handle_new_state %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - is_current = False - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - # Oh, we don't have any state for this yet. - is_current = True - elif (current.pdu_id == new_pdu.prev_state_id - and current.origin == new_pdu.prev_state_origin): - # Oh! A direct clobber. Just do it. - is_current = True - else: - ## - # Ok, now loop through until we get to a common ancestor. - max_new = int(new_pdu.power_level) - max_current = int(current.power_level) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - for branch, prev_state, state in enum_branches: - if not state: - raise RuntimeError( - "Could not find state_pdu %s %s" % - ( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - ) - - if branch == 0: - max_new = max(int(state.depth), max_new) - else: - max_current = max(int(state.depth), max_current) - - is_current = max_new > max_current - - if is_current: - logger.debug("handle_new_state make current") - - # Right, this is a new thing, so woo, just insert it. - txn.execute( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - }, - CurrentStateTable.EntryType( - *(new_pdu.__dict__[k] for k in CurrentStateTable.fields) - ) - ) - else: - logger.debug("handle_new_state not current") - - logger.debug("handle_new_state done") - - return is_current - - @log_function - def _enumerate_state_branches(self, txn, pdu_a, pdu_b): - branch_a = pdu_a - branch_b = pdu_b - - while True: - if (branch_a.pdu_id == branch_b.pdu_id - and branch_a.origin == branch_b.origin): - # Woo! We found a common ancestor - logger.debug("_enumerate_state_branches Found common ancestor") - break - - do_branch_a = ( - hasattr(branch_a, "prev_state_id") and - branch_a.prev_state_id - ) - - do_branch_b = ( - hasattr(branch_b, "prev_state_id") and - branch_b.prev_state_id - ) - - logger.debug( - "do_branch_a=%s, do_branch_b=%s", - do_branch_a, do_branch_b - ) - - if do_branch_a and do_branch_b: - do_branch_a = int(branch_a.depth) > int(branch_b.depth) - - if do_branch_a: - pdu_tuple = PduIdTuple( - branch_a.prev_state_id, - branch_a.prev_state_origin - ) - - prev_branch = branch_a - - logger.debug("getting branch_a prev %s", pdu_tuple) - branch_a = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_a: - branch_a = Pdu.from_pdu_tuple(branch_a) - - logger.debug("branch_a=%s", branch_a) - - yield (0, prev_branch, branch_a) - - if not branch_a: - break - elif do_branch_b: - pdu_tuple = PduIdTuple( - branch_b.prev_state_id, - branch_b.prev_state_origin - ) - - prev_branch = branch_b - - logger.debug("getting branch_b prev %s", pdu_tuple) - branch_b = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_b: - branch_b = Pdu.from_pdu_tuple(branch_b) - - logger.debug("branch_b=%s", branch_b) - - yield (1, prev_branch, branch_b) - - if not branch_b: - break - else: - break - - -class PdusTable(Table): - table_name = "pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "ts", - "depth", - "is_state", - "content_json", - "unrecognized_keys", - "outlier", - "have_processed", - ] - - EntryType = namedtuple("PdusEntry", fields) - - -class PduDestinationsTable(Table): - table_name = "pdu_destinations" - - fields = [ - "pdu_id", - "origin", - "destination", - "delivered_ts", - ] - - EntryType = namedtuple("PduDestinationsEntry", fields) - - -class PduEdgesTable(Table): - table_name = "pdu_edges" - - fields = [ - "pdu_id", - "origin", - "prev_pdu_id", - "prev_origin", - "context" - ] - - EntryType = namedtuple("PduEdgesEntry", fields) - - -class PduForwardExtremitiesTable(Table): - table_name = "pdu_forward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduForwardExtremitiesEntry", fields) - - -class PduBackwardExtremitiesTable(Table): - table_name = "pdu_backward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduBackwardExtremitiesEntry", fields) - - -class ContextDepthTable(Table): - table_name = "context_depth" - - fields = [ - "context", - "min_depth", - ] - - EntryType = namedtuple("ContextDepthEntry", fields) - - -class StatePdusTable(Table): - table_name = "state_pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - "power_level", - "prev_state_id", - "prev_state_origin", - ] - - EntryType = namedtuple("StatePdusEntry", fields) - - -class CurrentStateTable(Table): - table_name = "current_state" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - ] - - EntryType = namedtuple("CurrentStateEntry", fields) - -_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable) - - -# TODO: These should probably be put somewhere more sensible -PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin")) - -PduEntry = _pdu_state_joiner.EntryType -""" We are always interested in the join of the PdusTable and StatePdusTable, -rather than just the PdusTable. - -This does not include a prev_pdus key. -""" - -PduTuple = namedtuple( - "PduTuple", - ("pdu_entry", "prev_pdu_list") -) -""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent -the `prev_pdus` key of a PDU. -""" |