diff --git a/tests/test_state.py b/tests/test_state.py
index 95f81bebae..504530b49a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Dict, List, Optional
+from typing import Collection, Dict, List, Optional, cast
from unittest.mock import Mock
from twisted.internet import defer
@@ -21,7 +21,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.state import StateHandler, StateResolutionHandler
+from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
+from synapse.util import Clock
+from synapse.util.macaroons import MacaroonGenerator
from tests import unittest
@@ -97,6 +99,10 @@ class _DummyStore:
state_group = self._next_group
self._next_group += 1
+ if current_state_ids is None:
+ current_state_ids = dict(self._group_to_state[prev_group])
+ current_state_ids.update(delta_ids)
+
self._group_to_state[state_group] = dict(current_state_ids)
return state_group
@@ -129,7 +135,9 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
- async def get_state_group_for_events(self, event_ids):
+ async def get_state_group_for_events(
+ self, event_ids, await_full_state: bool = True
+ ):
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
@@ -190,13 +198,20 @@ class StateTestCase(unittest.TestCase):
"get_clock",
"get_state_resolution_handler",
"get_account_validity_handler",
+ "get_macaroon_generator",
+ "get_instance_name",
+ "get_simple_http_client",
"hostname",
]
)
+ clock = cast(Clock, MockClock())
hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock(main=self.dummy_store)
hs.get_state_handler.return_value = None
- hs.get_clock.return_value = MockClock()
+ hs.get_clock.return_value = clock
+ hs.get_macaroon_generator.return_value = MacaroonGenerator(
+ clock, "tesths", b"verysecret"
+ )
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage_controllers.return_value = storage_controllers
@@ -447,6 +462,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
+ partial_state=False,
)
)
@@ -477,6 +493,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
+ partial_state=False,
)
)
@@ -749,3 +766,43 @@ class StateTestCase(unittest.TestCase):
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result
+
+ def test_make_state_cache_entry(self):
+ "Test that calculating a prev_group and delta is correct"
+
+ new_state = {
+ ("a", ""): "E",
+ ("b", ""): "E",
+ ("c", ""): "E",
+ ("d", ""): "E",
+ }
+
+ # old_state_1 has fewer differences to new_state than old_state_2, but
+ # the delta involves deleting a key, which isn't allowed in the deltas,
+ # so we should pick old_state_2 as the prev_group.
+
+ # `old_state_1` has two differences: `a` and `e`
+ old_state_1 = {
+ ("a", ""): "F",
+ ("b", ""): "E",
+ ("c", ""): "E",
+ ("d", ""): "E",
+ ("e", ""): "E",
+ }
+
+ # `old_state_2` has three differences: `a`, `c` and `d`
+ old_state_2 = {
+ ("a", ""): "F",
+ ("b", ""): "E",
+ ("c", ""): "F",
+ ("d", ""): "F",
+ }
+
+ entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2})
+
+ self.assertEqual(entry.prev_group, 2)
+
+ # There are three changes from `old_state_2` to `new_state`
+ self.assertEqual(
+ entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
+ )
|