summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py65
1 files changed, 61 insertions, 4 deletions
diff --git a/tests/test_state.py b/tests/test_state.py

index 95f81bebae..504530b49a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, List, Optional +from typing import Collection, Dict, List, Optional, cast from unittest.mock import Mock from twisted.internet import defer @@ -21,7 +21,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.state import StateHandler, StateResolutionHandler +from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry +from synapse.util import Clock +from synapse.util.macaroons import MacaroonGenerator from tests import unittest @@ -97,6 +99,10 @@ class _DummyStore: state_group = self._next_group self._next_group += 1 + if current_state_ids is None: + current_state_ids = dict(self._group_to_state[prev_group]) + current_state_ids.update(delta_ids) + self._group_to_state[state_group] = dict(current_state_ids) return state_group @@ -129,7 +135,9 @@ class _DummyStore: async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier - async def get_state_group_for_events(self, event_ids): + async def get_state_group_for_events( + self, event_ids, await_full_state: bool = True + ): res = {} for event in event_ids: res[event] = self._event_to_state_group[event] @@ -190,13 +198,20 @@ class StateTestCase(unittest.TestCase): "get_clock", "get_state_resolution_handler", "get_account_validity_handler", + "get_macaroon_generator", + "get_instance_name", + "get_simple_http_client", "hostname", ] ) + clock = cast(Clock, MockClock()) hs.config = default_config("tesths", True) hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None - hs.get_clock.return_value = MockClock() + hs.get_clock.return_value = clock + hs.get_macaroon_generator.return_value = MacaroonGenerator( + clock, "tesths", b"verysecret" + ) hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_storage_controllers.return_value = storage_controllers @@ -447,6 +462,7 @@ class StateTestCase(unittest.TestCase): state_ids_before_event={ (e.type, e.state_key): e.event_id for e in old_state }, + partial_state=False, ) ) @@ -477,6 +493,7 @@ class StateTestCase(unittest.TestCase): state_ids_before_event={ (e.type, e.state_key): e.event_id for e in old_state }, + partial_state=False, ) ) @@ -749,3 +766,43 @@ class StateTestCase(unittest.TestCase): result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result + + def test_make_state_cache_entry(self): + "Test that calculating a prev_group and delta is correct" + + new_state = { + ("a", ""): "E", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + } + + # old_state_1 has fewer differences to new_state than old_state_2, but + # the delta involves deleting a key, which isn't allowed in the deltas, + # so we should pick old_state_2 as the prev_group. + + # `old_state_1` has two differences: `a` and `e` + old_state_1 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + ("e", ""): "E", + } + + # `old_state_2` has three differences: `a`, `c` and `d` + old_state_2 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "F", + ("d", ""): "F", + } + + entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2}) + + self.assertEqual(entry.prev_group, 2) + + # There are three changes from `old_state_2` to `new_state` + self.assertEqual( + entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"} + )