summary refs log tree commit diff
path: root/tests/storage/databases/test_state_store.py
blob: 492ad858da2d5b99cc95316e20df8fee27281cf8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from typing import Dict, List, Sequence, Tuple
from unittest.mock import patch

from twisted.internet.defer import Deferred, ensureDeferred

from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap

from tests.unittest import HomeserverTestCase


class StateGroupInflightCachingTestCase(HomeserverTestCase):
    """Tests for the in-flight request handling of the state group data store.

    `StateGroupDataStore._get_state_groups_from_groups` is patched with a fake
    that records each call and hands back an unfired `Deferred`. This lets a
    test pretend the database is slow and decide exactly when, and with what,
    each "database request" completes.
    """

    def setUp(self) -> None:
        super().setUp()

        # Log of every call made to the patched method, stored as
        # `(groups, state_filter, d)` where `d` is the still-pending Deferred
        # that was returned to the caller. Created before the patch is started
        # so the fake can never run without somewhere to record its call.
        self.gsgfg_calls: List[
            Tuple[Tuple[int, ...], StateFilter, "Deferred[Dict[int, StateMap[str]]]"]
        ] = []

        # Patch out the `_get_state_groups_from_groups`.
        # This is useful because it lets us pretend we have a slow database.
        gsgfg_patch = patch(
            "synapse.storage.databases.state.store.StateGroupDataStore._get_state_groups_from_groups",
            self._fake_get_state_groups_from_groups,
        )
        gsgfg_patch.start()
        # Undo the patch when the test finishes so other tests are unaffected.
        self.addCleanup(gsgfg_patch.stop)

    def prepare(self, reactor, clock, homeserver) -> None:
        """Keep handles to the state storage objects exercised by the tests."""
        super().prepare(reactor, clock, homeserver)
        self.state_storage = homeserver.get_storage().state
        self.state_datastore = homeserver.get_datastores().state

    def _fake_get_state_groups_from_groups(
        self, groups: Sequence[int], state_filter: StateFilter
    ) -> "Deferred[Dict[int, StateMap[str]]]":
        """
        Stand-in for `StateGroupDataStore._get_state_groups_from_groups`.

        Records the call in `self.gsgfg_calls` and returns a Deferred that is
        left unfired, so the request stays pending until a test completes it
        (for example with `_complete_request_fake`).
        """
        d: "Deferred[Dict[int, StateMap[str]]]" = Deferred()
        # Store the groups as a tuple so tests can compare them by value.
        self.gsgfg_calls.append((tuple(groups), state_filter, d))
        return d

    def _complete_request_fake(
        self,
        groups: Tuple[int, ...],
        state_filter: StateFilter,
        d: "Deferred[Dict[int, StateMap[str]]]",
    ) -> None:
        """
        Assemble a fake database response and complete the database request.

        Args:
            groups: the state groups that were requested.
            state_filter: the filter that was requested; it determines which
                made-up entries appear in the response.
            d: the pending Deferred returned by the fake database call. It is
                fired with a dict mapping each group to its fabricated state.
        """

        result: Dict[int, StateMap[str]] = {}

        for group in groups:
            group_result: MutableStateMap[str] = {}
            result[group] = group_result

            for state_type, state_keys in state_filter.types.items():
                if state_keys is None:
                    # No specific state keys were listed for this type, so
                    # invent a couple of arbitrary entries to stand in for
                    # "whatever state keys exist".
                    group_result[
                        (state_type, "wild wombat")
                    ] = f"{group} {state_type} wild wombat"
                    group_result[
                        (state_type, "wild spqr")
                    ] = f"{group} {state_type} wild spqr"
                else:
                    # Exactly one entry per explicitly requested state key.
                    for state_key in state_keys:
                        group_result[
                            (state_type, state_key)
                        ] = f"{group} {state_type} {state_key}"

            if state_filter.include_others:
                # Represent the types not named in the filter with one entry.
                group_result[("something.else", "wild")] = "card"

        d.callback(result)

    def test_duplicate_requests_deduplicated(self) -> None:
        """
        Two identical requests made while the first is still in flight must
        result in only one database call, and both must get the same result.
        """
        req1 = ensureDeferred(
            self.state_datastore._get_state_for_group_using_inflight_cache(
                42, StateFilter.all()
            )
        )
        self.pump(by=0.1)

        # This should have gone to the database
        self.assertEqual(len(self.gsgfg_calls), 1)
        self.assertFalse(req1.called)

        req2 = ensureDeferred(
            self.state_datastore._get_state_for_group_using_inflight_cache(
                42, StateFilter.all()
            )
        )
        self.pump(by=0.1)

        # No more calls should have gone to the database
        self.assertEqual(len(self.gsgfg_calls), 1)
        self.assertFalse(req1.called)
        self.assertFalse(req2.called)

        groups, sf, d = self.gsgfg_calls[0]
        self.assertEqual(groups, (42,))
        self.assertEqual(sf, StateFilter.all())

        # Now we can complete the request
        self._complete_request_fake(groups, sf, d)

        # Both requests are expected to resolve to the same fake response.
        expected = {("something.else", "wild"): "card"}
        self.assertEqual(self.get_success(req1), expected)
        self.assertEqual(self.get_success(req2), expected)