# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

from ._base import SQLBaseStore, Table, JoinHelper

from synapse.util.logutils import log_function

from collections import namedtuple

import logging

logger = logging.getLogger(__name__)


class PduStore(SQLBaseStore):
    """A collection of queries for handling PDUs.

    The public methods hand one of the ``_``-prefixed helpers to
    ``self._db_pool.runInteraction`` and return whatever that call returns;
    the helpers receive the transaction/cursor object as ``txn``.
    """

    def get_pdu(self, pdu_id, origin):
        """Given a pdu_id and origin, get a PDU.

        Args:
            pdu_id (str)
            origin (str)

        Returns:
            The result of the interaction: a PduTuple, or None if the pdu
            does not exist in the database.
        """

        return self._db_pool.runInteraction(
            self._get_pdu_tuple, pdu_id, origin
        )

    def _get_pdu_tuple(self, txn, pdu_id, origin):
        """Fetch a single PduTuple, or None if there is no matching pdu."""
        res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
        return res[0] if res else None

    def _get_pdu_tuples(self, txn, pdu_id_tuples):
        """Fetch the PduTuples for a list of (pdu_id, origin) pairs.

        Args:
            txn
            pdu_id_tuples (iterable): Pairs of (pdu_id, origin).

        Returns:
            list: One PduTuple per pair that has a row in the pdus table, in
            input order. Pairs without a row are silently skipped. The
            `prev_pdu_list` of each tuple is a list of plain
            (prev_pdu_id, prev_origin) tuples read from the edges table.
        """
        # Neither statement depends on the pdu being looked up, so build them
        # once rather than on every iteration.
        edges_query = PduEdgesTable.select_statement(
            "pdu_id = ? AND origin = ?"
        )

        pdu_query = (
            "SELECT %(fields)s FROM %(pdus)s as p "
            "LEFT JOIN %(state)s as s "
            "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
            "WHERE p.pdu_id = ? AND p.origin = ? "
        ) % {
            "fields": _pdu_state_joiner.get_fields(
                PdusTable="p", StatePdusTable="s"),
            "pdus": PdusTable.table_name,
            "state": StatePdusTable.table_name,
        }

        results = []
        for pdu_id, origin in pdu_id_tuples:
            txn.execute(edges_query, (pdu_id, origin))

            edges = [
                (r.prev_pdu_id, r.prev_origin)
                for r in PduEdgesTable.decode_results(txn.fetchall())
            ]

            txn.execute(pdu_query, (pdu_id, origin))

            row = txn.fetchone()
            if row:
                results.append(PduTuple(PduEntry(*row), edges))

        return results

    def get_current_state_for_context(self, context):
        """Get a list of PDUs that represent the current state for a given
        context

        Args:
            context (str)

        Returns:
            The result of the interaction: a list of PduTuples
        """

        return self._db_pool.runInteraction(
            self._get_current_state_for_context,
            context
        )

    def _get_current_state_for_context(self, txn, context):
        """Look up every (pdu_id, origin) in the current state table for
        `context` and resolve them to PduTuples."""
        query = (
            "SELECT pdu_id, origin FROM %s WHERE context = ?"
            % CurrentStateTable.table_name
        )

        logger.debug("get_current_state %s, Args=%s", query, context)
        txn.execute(query, (context,))

        res = txn.fetchall()

        logger.debug("get_current_state %d results", len(res))

        return self._get_pdu_tuples(txn, res)

    def _persist_pdu_txn(self, txn, prev_pdus, cols):
        """Inserts a (non-state) PDU into the database.

        Args:
            txn
            prev_pdus (list): (pdu_id, origin) pairs this PDU references.
            cols (dict): The columns to insert into the PdusTable. Any field
                of PdusTable missing from the dict is inserted as None.
        """
        entry = PdusTable.EntryType(
            **{k: cols.get(k, None) for k in PdusTable.fields}
        )

        txn.execute(PdusTable.insert_statement(), entry)

        self._handle_prev_pdus(
            txn, entry.outlier, entry.pdu_id, entry.origin,
            prev_pdus, entry.context
        )

    def mark_pdu_as_processed(self, pdu_id, pdu_origin):
        """Mark a received PDU as processed.

        Args:
            pdu_id (str)
            pdu_origin (str)
        """

        return self._db_pool.runInteraction(
            self._mark_as_processed, pdu_id, pdu_origin
        )

    def _mark_as_processed(self, txn, pdu_id, pdu_origin):
        """Set `have_processed` on the single PDU identified by
        (pdu_id, pdu_origin)."""
        # The WHERE clause is essential: without it every row in the table
        # would be flagged as processed.
        query = (
            "UPDATE %s SET have_processed = 1 "
            "WHERE pdu_id = ? AND origin = ?"
        ) % PdusTable.table_name

        txn.execute(query, (pdu_id, pdu_origin))

    def get_all_pdus_from_context(self, context):
        """Get a list of all PDUs for a given context."""
        return self._db_pool.runInteraction(
            self._get_all_pdus_from_context, context,
        )

    def _get_all_pdus_from_context(self, txn, context):
        """Resolve every pdu row stored for `context` to a PduTuple."""
        query = (
            "SELECT pdu_id, origin FROM %s "
            "WHERE context = ?"
        ) % PdusTable.table_name

        txn.execute(query, (context,))

        return self._get_pdu_tuples(txn, txn.fetchall())

    def get_backfill(self, context, pdu_list, limit):
        """Get a list of Pdus for a given topic that occurred before (and
        including) the pdus in pdu_list. Return a list of max size `limit`.

        Args:
            context (str)
            pdu_list (list): (pdu_id, origin) pairs to start from. The list
                is not modified.
            limit (int)

        Return:
            The result of the interaction: a list of PduTuples
        """
        return self._db_pool.runInteraction(
            self._get_backfill, context, pdu_list, limit
        )

    def _get_backfill(self, txn, context, pdu_list, limit):
        """Walk the edges table backwards from `pdu_list`, layer by layer,
        collecting (pdu_id, origin) pairs, then resolve them to PduTuples.

        The seeds in `pdu_list` are always included, even if there are more
        of them than `limit`; at most `limit - len(pdu_list)` further ids are
        added.
        """
        logger.debug(
            "backfill: %s, %s, %s",
            context, repr(pdu_list), limit
        )

        # We seed the pdu_results with the things from the pdu_list. Copy it
        # so that extending the results never mutates the caller's list.
        pdu_results = list(pdu_list)

        front = list(pdu_list)

        query = (
            "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
            "WHERE context = ? AND pdu_id = ? AND origin = ? "
            "LIMIT ?"
        ) % {
            "edges_table": PduEdgesTable.table_name,
        }

        # We iterate through all pdu_ids in `front` to select their previous
        # pdus. These are dumped in `new_front`. We continue until we reach the
        # limit *or* new_front is empty (i.e., we've run out of things to
        # select)
        while front and len(pdu_results) < limit:
            # Budget shared by every pdu in this layer, so that the layer as
            # a whole cannot push the result past `limit`.
            remaining = limit - len(pdu_results)

            new_front = []
            for pdu_id, origin in front:
                if remaining <= 0:
                    # Never issue LIMIT 0 / a negative LIMIT (SQLite treats a
                    # negative LIMIT as "no limit").
                    break

                logger.debug(
                    "_backfill_interaction: i=%s, o=%s",
                    pdu_id, origin
                )

                txn.execute(
                    query,
                    (context, pdu_id, origin, remaining)
                )

                for row in txn.fetchall():
                    logger.debug(
                        "_backfill_interaction: got i=%s, o=%s",
                        *row
                    )
                    new_front.append(row)
                    remaining -= 1

            front = new_front
            pdu_results.extend(new_front)

        # We also want to update the `prev_pdus` attributes before returning.
        return self._get_pdu_tuples(txn, pdu_results)

    def get_min_depth_for_context(self, context):
        """Get the current minimum depth for a context

        Args:
            context (str)

        Returns:
            The result of the interaction: the stored min_depth, or None if
            the context has no row.
        """
        return self._db_pool.runInteraction(
            self._get_min_depth_for_context, context
        )

    def _get_min_depth_for_context(self, txn, context):
        """Interaction entry point; delegates to
        `_get_min_depth_interaction`."""
        return self._get_min_depth_interaction(txn, context)

    def _get_min_depth_interaction(self, txn, context):
        """Return the `min_depth` stored for `context`, or None if there is
        no row for it."""
        txn.execute(
            "SELECT min_depth FROM %s WHERE context = ?"
            % ContextDepthTable.table_name,
            (context,)
        )

        row = txn.fetchone()

        return row[0] if row else None

    def _update_min_depth_for_context_txn(self, txn, context, depth):
        """Update the minimum `depth` of the given context, which is the line
        on which we stop backfilling backwards.

        The stored value is only replaced when there is none yet or when
        `depth` is strictly smaller than it.

        Args:
            txn
            context (str)
            depth (int)
        """
        min_depth = self._get_min_depth_interaction(txn, context)

        # Compare against None explicitly: a stored min_depth of 0 is a real
        # value and must not be overwritten by a larger depth.
        do_insert = min_depth is None or depth < min_depth

        if do_insert:
            txn.execute(
                "INSERT OR REPLACE INTO %s (context, min_depth) "
                "VALUES (?,?)" % ContextDepthTable.table_name,
                (context, depth)
            )

    def _get_latest_pdus_in_context(self, txn, context):
        """Gets a list of the most current pdus for a given context. This is
        used when we are sending a Pdu and need to fill out the `prev_pdus`
        key

        Args:
            txn
            context

        Returns:
            list: (pdu_id, origin, depth) tuples, one for each forward
            extremity of the context that has a row in the pdus table.
        """
        query = (
            "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
            "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
            "AND f.origin = p.origin "
            "WHERE f.context = ?"
        ) % {
            "pdus": PdusTable.table_name,
            "forward": PduForwardExtremitiesTable.table_name,
        }

        logger.debug("get_prev query: %s", query)

        txn.execute(
            query,
            (context, )
        )

        results = txn.fetchall()

        return [(row[0], row[1], row[2]) for row in results]

    @defer.inlineCallbacks
    def get_oldest_pdus_in_context(self, context):
        """Get a list of Pdus that we haven't backfilled beyond yet (and
        haven't seen). This list is used when we want to backfill backwards
        and is the list we send to the remote server.

        Args:
            context (str)

        Returns:
            Deferred firing with a list of PduIdTuple, read from the
            backward extremities table.
        """
        # NOTE(review): `_execute` comes from the base class; its first
        # argument is passed as None here - presumably "no row decoder".
        results = yield self._execute(
            None,
            "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
            % {"back": PduBackwardExtremitiesTable.table_name, },
            context
        )

        defer.returnValue([PduIdTuple(i, o) for i, o in results])

    def is_pdu_new(self, pdu_id, origin, context, depth):
        """For a given Pdu, try and figure out if it's 'new', i.e., if it's
        not something we got randomly from the past, for example when we
        request the current state of the room that will probably return a bunch
        of pdus from before we joined.

        Args:
            pdu_id (str)
            origin (str)
            context (str)
            depth (int)

        Returns:
            The result of the interaction: a bool
        """

        return self._db_pool.runInteraction(
            self._is_pdu_new,
            pdu_id=pdu_id,
            origin=origin,
            context=context,
            depth=depth
        )

    def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
        """Return True if the pdu is deeper than the shallowest pdu that
        points at a backward extremity (or there is no such pdu), or if it
        is a forward extremity; False otherwise."""
        # If depth > min depth in back table, then we classify it as new.
        # OR if there is nothing in the back table, then it kinda needs to
        # be a new thing.
        query = (
            "SELECT min(p.depth) FROM %(edges)s as e "
            "INNER JOIN %(back)s as b "
            "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
            "INNER JOIN %(pdus)s as p "
            "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
            "WHERE p.context = ?"
        ) % {
            "pdus": PdusTable.table_name,
            "edges": PduEdgesTable.table_name,
            "back": PduBackwardExtremitiesTable.table_name,
        }

        txn.execute(query, (context,))

        # An aggregate without GROUP BY always yields exactly one row; the
        # value is NULL/None when no rows matched.
        min_depth, = txn.fetchone()

        # Test for None explicitly so that a genuine minimum depth of 0 is
        # still compared against rather than treated as "nothing there".
        if min_depth is None or depth > int(min_depth):
            logger.debug(
                "is_new true: id=%s, o=%s, d=%s min_depth=%s",
                pdu_id, origin, depth, min_depth
            )
            return True

        # If this pdu is in the forwards table, then it also is a new one
        query = (
            "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
        ) % {
            "forward": PduForwardExtremitiesTable.table_name,
        }

        txn.execute(query, (pdu_id, origin))

        # Did we get anything?
        if txn.fetchall():
            logger.debug(
                "is_new true: id=%s, o=%s, d=%s was forward",
                pdu_id, origin, depth
            )
            return True

        logger.debug(
            "is_new false: id=%s, o=%s, d=%s",
            pdu_id, origin, depth
        )

        # FINE THEN. It's probably old.
        return False

    @staticmethod
    @log_function
    def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
                          context):
        """Record the edges of a newly inserted pdu and, unless it is an
        outlier, update the forward and backward extremity tables.

        Args:
            txn
            outlier: Truthy if the pdu is an outlier; extremities are then
                left untouched.
            pdu_id (str)
            origin (str)
            prev_pdus (list): (pdu_id, origin) pairs referenced by the pdu.
            context (str)
        """
        txn.executemany(
            PduEdgesTable.insert_statement(),
            [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
        )

        # Update the extremities table if this is not an outlier.
        if not outlier:

            # First, we delete the pdus that the new one references from the
            # forwards extremities table.
            query = (
                "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
                % PduForwardExtremitiesTable.table_name
            )
            txn.executemany(query, prev_pdus)

            # We only insert as a forward extremity the new pdu if there are
            # no other pdus that reference it as a prev pdu
            query = (
                "INSERT INTO %(table)s (pdu_id, origin, context) "
                "SELECT ?, ?, ? WHERE NOT EXISTS ("
                "SELECT 1 FROM %(pdu_edges)s WHERE "
                "prev_pdu_id = ? AND prev_origin = ?"
                ")"
            ) % {
                "table": PduForwardExtremitiesTable.table_name,
                "pdu_edges": PduEdgesTable.table_name
            }

            logger.debug("query: %s", query)

            txn.execute(query, (pdu_id, origin, context, pdu_id, origin))

            # Insert all the prev_pdus as a backwards thing, they'll get
            # deleted in a second if they're incorrect anyway.
            txn.executemany(
                PduBackwardExtremitiesTable.insert_statement(),
                [(i, o, context) for i, o in prev_pdus]
            )

            # Also delete from the backwards extremities table all ones that
            # reference pdus that we have already seen
            query = (
                "DELETE FROM %(pdu_back)s WHERE EXISTS ("
                "SELECT 1 FROM %(pdus)s AS pdus "
                "WHERE "
                "%(pdu_back)s.pdu_id = pdus.pdu_id "
                "AND %(pdu_back)s.origin = pdus.origin "
                "AND not pdus.outlier "
                ")"
            ) % {
                "pdu_back": PduBackwardExtremitiesTable.table_name,
                "pdus": PdusTable.table_name,
            }
            txn.execute(query)


class StatePduStore(SQLBaseStore):
    """A collection of queries for handling state PDUs.

    NOTE(review): `_persist_state_txn` calls `self._handle_prev_pdus`, which
    this class does not define (PduStore does) - presumably both stores are
    mixed into one class elsewhere; confirm.
    """

    # Shape of the first element returned by `_get_unresolved_state_tree`.
    # Index 0 (`new_branch`) and index 1 (`current_branch`) line up with the
    # branch numbers yielded by `_enumerate_state_branches`. Built once here
    # instead of creating a fresh class on every call.
    _StateReturnType = namedtuple(
        "StateReturnType", ["new_branch", "current_branch"]
    )

    def _persist_state_txn(self, txn, prev_pdus, cols):
        """Inserts a state PDU into the database

        Args:
            txn
            prev_pdus (list): (pdu_id, origin) pairs this PDU references.
            cols (dict): The columns to insert into the PdusTable and
                StatePdusTable. Fields missing from the dict are inserted as
                None.
        """
        pdu_entry = PdusTable.EntryType(
            **{k: cols.get(k, None) for k in PdusTable.fields}
        )
        state_entry = StatePdusTable.EntryType(
            **{k: cols.get(k, None) for k in StatePdusTable.fields}
        )

        logger.debug("Inserting pdu: %s", repr(pdu_entry))
        logger.debug("Inserting state: %s", repr(state_entry))

        txn.execute(PdusTable.insert_statement(), pdu_entry)
        txn.execute(StatePdusTable.insert_statement(), state_entry)

        self._handle_prev_pdus(
            txn,
            pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
            pdu_entry.context
        )

    def get_unresolved_state_tree(self, new_state_pdu):
        """Run `_get_unresolved_state_tree` for `new_state_pdu` as a database
        interaction and return the result of that call."""
        return self._db_pool.runInteraction(
            self._get_unresolved_state_tree, new_state_pdu
        )

    @log_function
    def _get_unresolved_state_tree(self, txn, new_pdu):
        """Collect the two chains of state pdus leading from `new_pdu` and
        from the current state back towards their common ancestor.

        Returns:
            tuple: `(branches, missing_branch)`. `branches` is a
            StateReturnType whose `new_branch` list starts with `new_pdu`
            and whose `current_branch` list starts with the current state
            (or is empty if there is none). `missing_branch` is None, or the
            index (0 = new, 1 = current) of the branch on which a previous
            state pdu could not be found in the database.
        """
        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        return_value = self._StateReturnType([new_pdu], [])

        if not current:
            logger.debug("get_unresolved_state_tree No current state.")
            return (return_value, None)

        return_value.current_branch.append(current)

        enum_branches = self._enumerate_state_branches(
            txn, new_pdu, current
        )

        missing_branch = None
        for branch, prev_state, state in enum_branches:
            if state:
                return_value[branch].append(state)
            else:
                # We don't have prev_state :(
                missing_branch = branch
                break

        return (return_value, missing_branch)

    def update_current_state(self, pdu_id, origin, context, pdu_type,
                             state_key):
        """Record the given pdu as the current state for
        (context, pdu_type, state_key), as a database interaction."""
        return self._db_pool.runInteraction(
            self._update_current_state,
            pdu_id, origin, context, pdu_type, state_key
        )

    def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
                              state_key):
        """INSERT OR REPLACE a row into the current state table."""
        query = (
            "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
        ) % {
            "curr": CurrentStateTable.table_name,
            "fields": CurrentStateTable.get_fields_string(),
            "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
        }

        query_args = CurrentStateTable.EntryType(
            pdu_id=pdu_id,
            origin=origin,
            context=context,
            pdu_type=pdu_type,
            state_key=state_key
        )

        txn.execute(query, query_args)

    def get_current_state_pdu(self, context, pdu_type, state_key):
        """For a given context, pdu_type, state_key 3-tuple, return what is
        currently considered the current state.

        Args:
            context (str)
            pdu_type (str)
            state_key (str)

        Returns:
            The result of the interaction: a PduEntry, or None if there is
            no current state for the 3-tuple.
        """

        return self._db_pool.runInteraction(
            self._get_current_state_pdu, context, pdu_type, state_key
        )

    def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
        """Interaction entry point; delegates to
        `_get_current_interaction`."""
        return self._get_current_interaction(txn, context, pdu_type, state_key)

    def _get_current_interaction(self, txn, context, pdu_type, state_key):
        """Return the PduEntry referenced by the current state table for
        (context, pdu_type, state_key), or None if there is none."""
        logger.debug(
            "_get_current_interaction %s %s %s",
            context, pdu_type, state_key
        )

        fields = _pdu_state_joiner.get_fields(
            PdusTable="p", StatePdusTable="s")

        current_query = (
            "SELECT %(fields)s FROM %(state)s as s "
            "INNER JOIN %(pdus)s as p "
            "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
            "INNER JOIN %(curr)s as c "
            "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
            "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
        ) % {
            "fields": fields,
            "curr": CurrentStateTable.table_name,
            "state": StatePdusTable.table_name,
            "pdus": PdusTable.table_name,
        }

        txn.execute(
            current_query,
            (context, pdu_type, state_key)
        )

        row = txn.fetchone()

        result = PduEntry(*row) if row else None

        if not result:
            logger.debug("_get_current_interaction not found")
        else:
            logger.debug(
                "_get_current_interaction found %s %s",
                result.pdu_id, result.origin
            )

        return result

    def get_next_missing_pdu(self, new_pdu):
        """When we get a new state pdu we need to check whether we need to do
        any conflict resolution, if we do then we need to check if we need
        to go back and request some more state pdus that we haven't seen yet.

        Args:
            new_pdu

        Returns:
            The result of the interaction: a PduIdTuple naming a pdu that we
            are missing, or None if we have all the pdus required to do the
            conflict resolution.
        """
        return self._db_pool.runInteraction(
            self._get_next_missing_pdu, new_pdu
        )

    def _get_next_missing_pdu(self, txn, new_pdu):
        """Return the id of the first previous-state pdu that is needed for
        conflict resolution but absent from the database, or None."""
        logger.debug(
            "get_next_missing_pdu %s %s",
            new_pdu.pdu_id, new_pdu.origin
        )

        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        # No current state, or current state has no predecessor: nothing to
        # walk back through.
        if (not current or not current.prev_state_id
                or not current.prev_state_origin):
            return None

        # Oh look, it's a straight clobber, so wooooo almost no-op.
        if (new_pdu.prev_state_id == current.pdu_id
                and new_pdu.prev_state_origin == current.origin):
            return None

        enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
        for branch, prev_state, state in enum_branches:
            if not state:
                return PduIdTuple(
                    prev_state.prev_state_id,
                    prev_state.prev_state_origin
                )

        return None

    def handle_new_state(self, new_pdu):
        """Actually perform conflict resolution on the new_pdu on the
        assumption we have all the pdus required to perform it.

        Args:
            new_pdu

        Returns:
            The result of the interaction: True if the new_pdu clobbered the
            current state, False if not
        """
        return self._db_pool.runInteraction(
            self._handle_new_state, new_pdu
        )

    def _handle_new_state(self, txn, new_pdu):
        """Decide whether `new_pdu` replaces the current state and, if so,
        write it to the current state table.

        Raises:
            RuntimeError: If a previous-state pdu needed to reach the common
                ancestor is not in the database.
        """
        logger.debug(
            "handle_new_state %s %s",
            new_pdu.pdu_id, new_pdu.origin
        )

        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        is_current = False

        if (not current or not current.prev_state_id
                or not current.prev_state_origin):
            # Oh, we don't have any state for this yet.
            is_current = True
        elif (current.pdu_id == new_pdu.prev_state_id
                and current.origin == new_pdu.prev_state_origin):
            # Oh! A direct clobber. Just do it.
            is_current = True
        else:
            ##
            # Ok, now loop through until we get to a common ancestor.
            max_new = int(new_pdu.power_level)
            max_current = int(current.power_level)

            enum_branches = self._enumerate_state_branches(
                txn, new_pdu, current
            )
            for branch, prev_state, state in enum_branches:
                if not state:
                    raise RuntimeError(
                        "Could not find state_pdu %s %s" %
                        (
                            prev_state.prev_state_id,
                            prev_state.prev_state_origin
                        )
                    )

                # NOTE(review): the maxima are seeded from `power_level` but
                # updated from `depth` - this looks like it may have been
                # meant to be `state.power_level`; behaviour kept as-is,
                # confirm before changing.
                if branch == 0:
                    max_new = max(int(state.depth), max_new)
                else:
                    max_current = max(int(state.depth), max_current)

            is_current = max_new > max_current

        if is_current:
            logger.debug("handle_new_state make current")

            # Right, this is a new thing, so woo, just insert it. Plain
            # attribute access (rather than `new_pdu.__dict__`) also works for
            # pdus that expose their fields via properties or slots.
            self._update_current_state(
                txn,
                new_pdu.pdu_id,
                new_pdu.origin,
                new_pdu.context,
                new_pdu.pdu_type,
                new_pdu.state_key
            )
        else:
            logger.debug("handle_new_state not current")

        logger.debug("handle_new_state done")

        return is_current

    @classmethod
    @log_function
    def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
        """Generator that walks the `prev_state` chains of two state pdus
        back until they meet.

        On every step the branch to follow is chosen as follows: if both
        heads have a `prev_state_id`, branch a is followed only when its
        depth is strictly greater than branch b's; otherwise whichever head
        has a `prev_state_id` is followed.

        Yields:
            tuple: `(branch, prev_state, state)` where `branch` is 0 for
            `pdu_a`'s chain and 1 for `pdu_b`'s, `prev_state` is the pdu
            whose predecessor was looked up, and `state` is that predecessor
            as a PduEntry, or None if it is not in the database (in which
            case iteration stops after the yield).
        """
        branch_a = pdu_a
        branch_b = pdu_b

        get_query = (
            "SELECT %(fields)s FROM %(pdus)s as p "
            "LEFT JOIN %(state)s as s "
            "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
            "WHERE p.pdu_id = ? AND p.origin = ? "
        ) % {
            "fields": _pdu_state_joiner.get_fields(
                PdusTable="p", StatePdusTable="s"),
            "pdus": PdusTable.table_name,
            "state": StatePdusTable.table_name,
        }

        while True:
            if (branch_a.pdu_id == branch_b.pdu_id
                    and branch_a.origin == branch_b.origin):
                # Woo! We found a common ancestor
                logger.debug("_enumerate_state_branches Found common ancestor")
                break

            do_branch_a = (
                hasattr(branch_a, "prev_state_id") and
                branch_a.prev_state_id
            )

            do_branch_b = (
                hasattr(branch_b, "prev_state_id") and
                branch_b.prev_state_id
            )

            logger.debug(
                "do_branch_a=%s, do_branch_b=%s",
                do_branch_a, do_branch_b
            )

            # Both chains can be advanced: step the deeper one so the two
            # heads converge; on a tie branch b is stepped.
            if do_branch_a and do_branch_b:
                do_branch_a = int(branch_a.depth) > int(branch_b.depth)

            if do_branch_a:
                pdu_tuple = PduIdTuple(
                    branch_a.prev_state_id,
                    branch_a.prev_state_origin
                )

                logger.debug("getting branch_a prev %s", pdu_tuple)
                txn.execute(get_query, pdu_tuple)

                prev_branch = branch_a

                res = txn.fetchone()
                branch_a = PduEntry(*res) if res else None

                logger.debug("branch_a=%s", branch_a)

                yield (0, prev_branch, branch_a)

                if not branch_a:
                    break
            elif do_branch_b:
                pdu_tuple = PduIdTuple(
                    branch_b.prev_state_id,
                    branch_b.prev_state_origin
                )

                logger.debug("getting branch_b prev %s", pdu_tuple)
                txn.execute(get_query, pdu_tuple)

                prev_branch = branch_b

                res = txn.fetchone()
                branch_b = PduEntry(*res) if res else None

                logger.debug("branch_b=%s", branch_b)

                yield (1, prev_branch, branch_b)

                if not branch_b:
                    break
            else:
                # Neither head has a predecessor to follow.
                break


class PdusTable(Table):
    """Describes the ``pdus`` table: its name, columns and row type."""

    table_name = "pdus"

    # Column names, in the order used for `EntryType`.
    fields = [
        "pdu_id", "origin", "context", "pdu_type",
        "ts", "depth", "is_state",
        "content_json", "unrecognized_keys",
        "outlier", "have_processed",
    ]

    # Row type with one attribute per column.
    EntryType = namedtuple("PdusEntry", fields)


class PduDestinationsTable(Table):
    """Describes the ``pdu_destinations`` table: its name, columns and row
    type."""

    table_name = "pdu_destinations"

    # Column names, in the order used for `EntryType`.
    fields = ["pdu_id", "origin", "destination", "delivered_ts"]

    # Row type with one attribute per column.
    EntryType = namedtuple("PduDestinationsEntry", fields)


class PduEdgesTable(Table):
    """Describes the ``pdu_edges`` table: its name, columns and row type.

    Each row links a pdu (`pdu_id`, `origin`) to one pdu it references
    (`prev_pdu_id`, `prev_origin`).
    """

    table_name = "pdu_edges"

    # Column names, in the order used for `EntryType`.
    fields = [
        "pdu_id", "origin",
        "prev_pdu_id", "prev_origin",
        "context",
    ]

    # Row type with one attribute per column.
    EntryType = namedtuple("PduEdgesEntry", fields)


class PduForwardExtremitiesTable(Table):
    """Describes the ``pdu_forward_extremities`` table: its name, columns
    and row type."""

    table_name = "pdu_forward_extremities"

    # Column names, in the order used for `EntryType`.
    fields = ["pdu_id", "origin", "context"]

    # Row type with one attribute per column.
    EntryType = namedtuple("PduForwardExtremitiesEntry", fields)


class PduBackwardExtremitiesTable(Table):
    """Describes the ``pdu_backward_extremities`` table: its name, columns
    and row type."""

    table_name = "pdu_backward_extremities"

    # Column names, in the order used for `EntryType`.
    fields = ["pdu_id", "origin", "context"]

    # Row type with one attribute per column.
    EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)


class ContextDepthTable(Table):
    """Describes the ``context_depth`` table: its name, columns and row
    type."""

    table_name = "context_depth"

    # Column names, in the order used for `EntryType`.
    fields = ["context", "min_depth"]

    # Row type with one attribute per column.
    EntryType = namedtuple("ContextDepthEntry", fields)


class StatePdusTable(Table):
    """Describes the ``state_pdus`` table: its name, columns and row type."""

    table_name = "state_pdus"

    # Column names, in the order used for `EntryType`.
    fields = [
        "pdu_id", "origin", "context",
        "pdu_type", "state_key", "power_level",
        "prev_state_id", "prev_state_origin",
    ]

    # Row type with one attribute per column.
    EntryType = namedtuple("StatePdusEntry", fields)


class CurrentStateTable(Table):
    """Describes the ``current_state`` table: its name, columns and row
    type."""

    table_name = "current_state"

    # Column names, in the order used for `EntryType`.
    fields = ["pdu_id", "origin", "context", "pdu_type", "state_key"]

    # Row type with one attribute per column.
    EntryType = namedtuple("CurrentStateEntry", fields)

# Helper over the PdusTable/StatePdusTable pair. The stores use its
# `get_fields(...)` to build the SELECT column list of their join queries and
# its `EntryType` to wrap the resulting rows.
_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)


# TODO: These should probably be put somewhere more sensible
# Identifies a single PDU by its id and the origin it came from.
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))

# Row type for the joined pdus/state_pdus query results.
PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.

This does not include a prev_pdus key.
"""

# Pairs a PduEntry with the pdus it references.
PduTuple = namedtuple(
    "PduTuple",
    ("pdu_entry", "prev_pdu_list")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""