# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Collection, Dict, List, Optional, cast
from unittest.mock import Mock

from twisted.internet import defer

from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
from synapse.util import Clock
from synapse.util.macaroons import MacaroonGenerator

from tests import unittest

from .utils import MockClock, default_config

# Counter used by create_event() to allocate fresh event IDs.
_next_event_id = 1000


def create_event(
    name=None,
    type=None,
    state_key=None,
    depth=2,
    event_id=None,
    prev_events: Optional[list] = None,
    **kwargs,
):
    """Build a test event from a plain dict via ``make_event_from_dict``.

    By default the event is sent by ``@user_id:example.com`` in room
    ``!room_id:example.com``; both can be overridden through ``kwargs``.

    Args:
        name: accepted so call sites can label events for readability; it
            does not affect the returned event.
        type: the event type.
        state_key: if not None, the event is given this state key.
        depth: the event's depth.
        event_id: the event ID to use. If not given, a fresh ``$<n>:test``
            ID is allocated from the module-level counter.
        prev_events: the event's prev events. Callers in this file pass
            ``(event_id, {})`` pairs, so this is not restricted to strings.
            Defaults to an empty list.
        **kwargs: extra top-level event fields. They are applied last, so
            they override any of the fields above (e.g. ``sender``).

    Returns:
        The event built by ``make_event_from_dict``.
    """
    global _next_event_id

    if not event_id:
        _next_event_id += 1
        event_id = f"${_next_event_id}:test"

    event_dict = {
        "event_id": event_id,
        "type": type,
        "sender": "@user_id:example.com",
        "room_id": "!room_id:example.com",
        "depth": depth,
        "prev_events": prev_events or [],
    }

    if state_key is not None:
        event_dict["state_key"] = state_key

    event_dict.update(kwargs)

    return make_event_from_dict(event_dict)


class _DummyStore:
    def __init__(self):
        self._event_to_state_group = {}
        self._group_to_state = {}

        self._event_id_to_event = {}

        self._next_group = 1

    async def get_state_groups_ids(self, room_id, event_ids):
        groups = {}
        for event_id in event_ids:
            group = self._event_to_state_group.get(event_id)
            if group:
                groups[group] = self._group_to_state[group]

        return groups

    async def get_state_ids_for_group(self, state_group, state_filter=None):
        return self._group_to_state[state_group]

    async def store_state_group(
        self, event_id, room_id, prev_group, delta_ids, current_state_ids
    ):
        state_group = self._next_group
        self._next_group += 1

        if current_state_ids is None:
            current_state_ids = dict(self._group_to_state[prev_group])
            current_state_ids.update(delta_ids)

        self._group_to_state[state_group] = dict(current_state_ids)

        return state_group

    async def get_events(self, event_ids, **kwargs):
        return {
            e_id: self._event_id_to_event[e_id]
            for e_id in event_ids
            if e_id in self._event_id_to_event
        }

    async def get_partial_state_events(
        self, event_ids: Collection[str]
    ) -> Dict[str, bool]:
        return {e: False for e in event_ids}

    async def get_state_group_delta(self, name):
        return None, None

    def register_events(self, events):
        for e in events:
            self._event_id_to_event[e.event_id] = e

    def register_event_context(self, event, context):
        self._event_to_state_group[event.event_id] = context.state_group

    def register_event_id_state_group(self, event_id, state_group):
        self._event_to_state_group[event_id] = state_group

    async def get_room_version_id(self, room_id):
        return RoomVersions.V1.identifier

    async def get_state_group_for_events(
        self, event_ids, await_full_state: bool = True
    ):
        res = {}
        for event in event_ids:
            res[event] = self._event_to_state_group[event]
        return res

    async def get_state_for_groups(self, groups):
        res = {}
        for group in groups:
            state = self._group_to_state[group]
            res[group] = state
        return res


class DictObj(dict):
    """A ``dict`` whose keys can also be read and written as attributes."""

    def __init__(self, **kwargs):
        dict.__init__(self, **kwargs)
        # Alias the instance namespace to the dict itself, so ``obj.key`` and
        # ``obj["key"]`` refer to the same storage.
        self.__dict__ = self


class Graph:
    """A small DAG of test events built from node descriptions.

    Args:
        nodes: map from event id to the keyword arguments passed to
            ``create_event`` for that event.
        edges: map from event id to the list of its prev event ids.
    """

    def __init__(self, nodes, edges):
        events = {}
        # Every node starts as a candidate leaf. Any node that another node
        # lists as a prev event is then removed. (The old code seeded this set
        # from the still-empty ``events`` dict, so it never had any leaves.)
        leaves = set(nodes)

        for event_id, fields in nodes.items():
            refs = edges.get(event_id)
            if refs:
                leaves.difference_update(refs)
                # prev events are passed as (event_id, {}) pairs.
                prev_events = [(r, {}) for r in refs]
            else:
                prev_events = []

            events[event_id] = create_event(
                event_id=event_id, prev_events=prev_events, **fields
            )

        self._leaves = leaves
        self._events_by_id = events
        # Sorted by depth (stable for ties), so walk() yields each event after
        # its parents as long as depths increase along edges.
        self._events = sorted(events.values(), key=lambda e: e.depth)

    def walk(self):
        """Iterate over all events in ascending depth order."""
        return iter(self._events)

    def get_leaves(self):
        """Iterate over the events that no other node lists as a prev event.

        Leaves are looked up by event id. Indexing the depth-sorted list
        with string ids raised TypeError.
        """
        return (self._events_by_id[i] for i in sorted(self._leaves))


class StateTestCase(unittest.TestCase):
    def setUp(self):
        """Build a StateHandler on top of a mocked homeserver.

        A single in-memory ``_DummyStore`` (``self.dummy_store``) backs both
        the main datastore and the state storage controller, so tests can
        seed events and state groups into it directly.
        """
        self.dummy_store = _DummyStore()
        # The same dummy store serves as both the `main` and `state` storage
        # controllers.
        storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
        # spec_set restricts the mock to exactly these attributes; getting or
        # setting anything else raises AttributeError.
        hs = Mock(
            spec_set=[
                "config",
                "get_datastores",
                "get_storage_controllers",
                "get_auth",
                "get_state_handler",
                "get_clock",
                "get_state_resolution_handler",
                "get_account_validity_handler",
                "get_macaroon_generator",
                "get_instance_name",
                "get_simple_http_client",
                "hostname",
            ]
        )
        clock = cast(Clock, MockClock())
        hs.config = default_config("tesths", True)
        hs.get_datastores.return_value = Mock(main=self.dummy_store)
        hs.get_state_handler.return_value = None
        hs.get_clock.return_value = clock
        hs.get_macaroon_generator.return_value = MacaroonGenerator(
            clock, "tesths", b"verysecret"
        )
        # NOTE(review): Auth(hs) is constructed only after config, clock,
        # datastores and the macaroon generator are wired up, presumably
        # because its constructor reads some of them. Keep this ordering.
        hs.get_auth.return_value = Auth(hs)
        # A fresh StateResolutionHandler is constructed on every call.
        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
        hs.get_storage_controllers.return_value = storage_controllers

        self.state = StateHandler(hs)
        # NOTE(review): `self.event_id` is not read by any visible test;
        # possibly vestigial.
        self.event_id = 0

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """A fork where only one branch changes state needs no resolution.

        C is the only state event besides the create event, so D's state
        before the event should be exactly C's state group.
        """
        node_fields = {
            "START": DictObj(
                type=EventTypes.Create, state_key="", content={}, depth=1
            ),
            "A": DictObj(type=EventTypes.Message, depth=2),
            "B": DictObj(type=EventTypes.Message, depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "D": DictObj(type=EventTypes.Message, depth=4),
        }
        prev_map = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
        graph = Graph(nodes=node_fields, edges=prev_map)

        self.dummy_store.register_events(graph.walk())

        contexts: Dict[str, EventContext] = {}
        # Graph.walk() yields events in depth order, so each event's prev
        # events already have a registered state group.
        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(self.state.compute_event_context(ev))
            self.dummy_store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        d_context = contexts["D"]
        state_before_d = yield defer.ensureDeferred(d_context.get_prev_state_ids())
        # Only START (create) and C (name) are state events.
        self.assertEqual(2, len(state_before_d))

        self.assertEqual(contexts["C"].state_group, d_context.state_group_before_event)
        self.assertEqual(d_context.state_group_before_event, d_context.state_group)

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Two branches each set the room name; the resolved state keeps C's."""
        creator = "@user_id:example.com"
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": creator},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key=creator,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
            "D": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
        graph = Graph(nodes=nodes, edges=edges)

        self.dummy_store.register_events(graph.walk())

        contexts = {}
        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(self.state.compute_event_context(ev))
            self.dummy_store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        # C ends up winning the resolution between B and C.
        c_context, d_context = contexts["C"], contexts["D"]

        state_before_d = yield defer.ensureDeferred(d_context.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "C"}, set(state_before_d.values()))

        self.assertEqual(c_context.state_group, d_context.state_group_before_event)
        self.assertEqual(d_context.state_group_before_event, d_context.state_group)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """A ban on one branch beats a name change by the banned user."""
        member_id = "@user_id:example.com"
        banned_id = "@user_id_2:example.com"
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": member_id},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key=member_id,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(
                type=EventTypes.Member,
                state_key=banned_id,
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender=banned_id,
            ),
            "E": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}
        graph = Graph(nodes=nodes, edges=edges)

        self.dummy_store.register_events(graph.walk())

        contexts = {}
        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(self.state.compute_event_context(ev))
            self.dummy_store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        # C ends up winning the resolution between C and D because bans win over
        # other changes.
        c_context, e_context = contexts["C"], contexts["E"]

        state_before_e = yield defer.ensureDeferred(e_context.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "B", "C"}, set(state_before_e.values()))
        self.assertEqual(c_context.state_group, e_context.state_group_before_event)
        self.assertEqual(e_context.state_group_before_event, e_context.state_group)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """A power-levels change on one branch beats a name change on the other."""
        creator_id = "@user_id:example.com"
        other_id = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": creator_id},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=creator_id,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=other_id,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {creator_id: 100, other_id: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={"events": {"m.room.name": 50}, "users": {other_id: 30}},
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=other_id),
            "D": DictObj(type=EventTypes.Message),
        }
        # A1 -> A2 -> ... -> A5 is a straight chain; B and C fork off A5 and
        # D merges them back together.
        trunk = ["A1", "A2", "A3", "A4", "A5"]
        edges = {child: [parent] for parent, child in zip(trunk, trunk[1:])}
        edges.update({"B": ["A5"], "C": ["A5"], "D": ["B", "C"]})
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        self.dummy_store.register_events(graph.walk())

        contexts = {}
        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(self.state.compute_event_context(ev))
            self.dummy_store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        # B ends up winning the resolution between B and C because power levels
        # win over other changes.
        b_context, d_context = contexts["B"], contexts["D"]

        state_before_d = yield defer.ensureDeferred(d_context.get_prev_state_ids())
        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"}, set(state_before_d.values())
        )

        self.assertEqual(b_context.state_group, d_context.state_group_before_event)
        self.assertEqual(d_context.state_group_before_event, d_context.state_group)

    def _add_depths(self, nodes, edges):
        """Fill in ``depth`` for every node that does not already have one.

        A node's depth is one more than the greatest depth among its prev
        events, as listed in ``edges``. Nodes that already have a ``depth``
        keep it. A node with neither a depth nor any prev events is treated
        as a root and given depth 1. Previously this case raised KeyError.

        Args:
            nodes: map from node id to a mutable mapping of event fields;
                updated in place.
            edges: map from node id to the list of its prev node ids.
        """

        def _get_depth(node_id):
            node = nodes[node_id]
            if "depth" not in node:
                prev_ids = edges.get(node_id, ())
                node["depth"] = (
                    max((_get_depth(prev) for prev in prev_ids), default=0) + 1
                )
            return node["depth"]

        for node_id in nodes:
            _get_depth(node_id)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """A message given explicit prior state leaves that state unchanged."""
        message = create_event(type="test_message", name="event")

        prior_events = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        prior_ids = [ev.event_id for ev in prior_events]

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(
                message,
                state_ids_before_event={
                    (ev.type, ev.state_key): ev.event_id for ev in prior_events
                },
                partial_state=False,
            )
        )

        before = yield defer.ensureDeferred(context.get_prev_state_ids())
        self.assertCountEqual(prior_ids, before.values())

        # A message does not change state, so "after" equals "before".
        after = yield defer.ensureDeferred(context.get_current_state_ids())
        self.assertCountEqual(prior_ids, after.values())

        self.assertIsNotNone(context.state_group_before_event)
        self.assertEqual(context.state_group_before_event, context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """A state event given explicit prior state gets a new state group.

        The new group is recorded as a one-entry delta on top of the group
        holding the state before the event.
        """
        state_event = create_event(type="state", state_key="", name="event")

        prior_events = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        prior_ids = [ev.event_id for ev in prior_events]

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(
                state_event,
                state_ids_before_event={
                    (ev.type, ev.state_key): ev.event_id for ev in prior_events
                },
                partial_state=False,
            )
        )

        before = yield defer.ensureDeferred(context.get_prev_state_ids())
        self.assertCountEqual(prior_ids, before.values())

        after = yield defer.ensureDeferred(context.get_current_state_ids())
        self.assertCountEqual(prior_ids + [state_event.event_id], after.values())

        self.assertIsNotNone(context.state_group_before_event)
        self.assertNotEqual(context.state_group_before_event, context.state_group)
        self.assertEqual(context.state_group_before_event, context.prev_group)
        self.assertEqual({("state", ""): state_event.event_id}, context.delta_ids)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """A message with one stored parent reuses the parent's state group."""
        parent_id = "prev_event_id"
        message = create_event(
            type="test_message", name="event2", prev_events=[(parent_id, {})]
        )

        parent_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        state_map = {(ev.type, ev.state_key): ev.event_id for ev in parent_state}

        parent_group = yield defer.ensureDeferred(
            self.dummy_store.store_state_group(
                parent_id, message.room_id, None, None, state_map
            )
        )
        self.dummy_store.register_event_id_state_group(parent_id, parent_group)

        context = yield defer.ensureDeferred(self.state.compute_event_context(message))

        after = yield defer.ensureDeferred(context.get_current_state_ids())
        expected_ids = {ev.event_id for ev in parent_state}
        self.assertEqual(expected_ids, set(after.values()))

        self.assertEqual(parent_group, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """A state event's prev state is the state of its single stored parent."""
        parent_id = "prev_event_id"
        state_event = create_event(
            type="state", state_key="", name="event2", prev_events=[(parent_id, {})]
        )

        parent_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        state_map = {(ev.type, ev.state_key): ev.event_id for ev in parent_state}

        parent_group = yield defer.ensureDeferred(
            self.dummy_store.store_state_group(
                parent_id, state_event.room_id, None, None, state_map
            )
        )
        self.dummy_store.register_event_id_state_group(parent_id, parent_group)

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(state_event)
        )

        before = yield defer.ensureDeferred(context.get_prev_state_ids())
        expected_ids = {ev.event_id for ev in parent_state}
        self.assertEqual(expected_ids, set(before.values()))

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """A message whose two parents disagree on state resolves to six keys.

        The keys are create, test1/1, test1/2, test2, test3/2 and test4.
        """
        parent_ids = ("event_id1", "event_id2")
        message = create_event(
            type="test_message",
            name="event3",
            prev_events=[(pid, {}) for pid in parent_ids],
        )

        create_ev = create_event(type=EventTypes.Create, state_key="")

        state_1 = [
            create_ev,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        state_2 = [
            create_ev,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        for state in (state_1, state_2):
            self.dummy_store.register_events(state)

        context = yield self._get_context(
            message, parent_ids[0], state_1, parent_ids[1], state_2
        )

        resolved = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(len(resolved), 6)

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """A state event whose two parents disagree on state resolves to six keys."""
        parent_ids = ("event_id1", "event_id2")
        state_event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(pid, {}) for pid in parent_ids],
        )

        create_ev = create_event(type=EventTypes.Create, state_key="")

        state_1 = [
            create_ev,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        state_2 = [
            create_ev,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        # Serve event lookups from a separate store that holds both parents'
        # state events.
        lookup_store = _DummyStore()
        for state in (state_1, state_2):
            lookup_store.register_events(state)
        self.dummy_store.get_events = lookup_store.get_events

        context = yield self._get_context(
            state_event, parent_ids[0], state_1, parent_ids[1], state_2
        )

        resolved = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(len(resolved), 6)

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """Conflicting ("test1", "1") state resolves to the deeper event.

        The check runs twice with the depths swapped between the two parents.
        This shows the depth decides the winner, not which parent the event
        came from.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )

        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
        )

        creation = create_event(
            type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
        )

        def state_with_test1_at(depth):
            # Shared auth state plus a fresh ("test1", "1") event at `depth`.
            return [
                creation,
                power_levels,
                member_event,
                create_event(type="test1", state_key="1", depth=depth),
            ]

        lookup_store = _DummyStore()
        self.dummy_store.get_events = lookup_store.get_events

        # Each round: (depth for parent 1, depth for parent 2, winning parent).
        # In the second round the depths are reversed, to make sure the depths
        # are actually used during state resolution.
        rounds = ((1, 2, 1), (2, 1, 0))
        for depth_1, depth_2, winner in rounds:
            state_1 = state_with_test1_at(depth_1)
            state_2 = state_with_test1_at(depth_2)
            lookup_store.register_events(state_1)
            lookup_store.register_events(state_2)

            context = yield self._get_context(
                event, first_parent, state_1, second_parent, state_2
            )

            resolved = yield defer.ensureDeferred(context.get_current_state_ids())

            expected = (state_1, state_2)[winner][3].event_id
            self.assertEqual(expected, resolved[("test1", "1")])

    @defer.inlineCallbacks
    def _get_context(
        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
    ):
        """Compute the context of ``event`` after giving its two parents state.

        Each ``prev_event_id_N`` is registered against a freshly stored state
        group whose state map is built from the events in ``old_state_N``.

        Returns:
            The result of ``self.state.compute_event_context(event)``.
        """
        parents = (
            (prev_event_id_1, old_state_1),
            (prev_event_id_2, old_state_2),
        )
        for parent_id, parent_state in parents:
            state_group = yield defer.ensureDeferred(
                self.dummy_store.store_state_group(
                    parent_id,
                    event.room_id,
                    None,
                    None,
                    {(e.type, e.state_key): e.event_id for e in parent_state},
                )
            )
            self.dummy_store.register_event_id_state_group(parent_id, state_group)

        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
        return context

    def test_make_state_cache_entry(self):
        """Test that calculating a prev_group and delta is correct"""

        new_state = {(key, ""): "E" for key in "abcd"}

        # old_state_1 has fewer differences to new_state than old_state_2, but
        # the delta involves deleting a key, which isn't allowed in the deltas,
        # so we should pick old_state_2 as the prev_group.

        # `old_state_1` has two differences: `a` and `e`
        old_state_1 = {**new_state, ("a", ""): "F", ("e", ""): "E"}

        # `old_state_2` has three differences: `a`, `c` and `d`
        old_state_2 = {**new_state, ("a", ""): "F", ("c", ""): "F", ("d", ""): "F"}

        candidates = {1: old_state_1, 2: old_state_2}
        entry = _make_state_cache_entry(new_state, candidates)

        self.assertEqual(2, entry.prev_group)

        # There are three changes from `old_state_2` to `new_state`
        expected_delta = {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
        self.assertEqual(expected_delta, entry.delta_ids)