diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 713a798703..5f83b637c5 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -18,6 +18,8 @@
#
#
import logging
+from copy import deepcopy
+from typing import Optional
from unittest.mock import patch
from parameterized import parameterized
@@ -33,20 +35,550 @@ from synapse.api.constants import (
RoomTypes,
)
from synapse.api.room_versions import RoomVersions
-from synapse.handlers.sliding_sync import SlidingSyncConfig
+from synapse.handlers.sliding_sync import RoomSyncConfig, StateValues
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, UserID
+from synapse.types.handlers import SlidingSyncConfig
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__)
+class RoomSyncConfigTestCase(TestCase):
+ def _assert_room_config_equal(
+ self,
+ actual: RoomSyncConfig,
+ expected: RoomSyncConfig,
+ message_prefix: Optional[str] = None,
+ ) -> None:
+ self.assertEqual(actual.timeline_limit, expected.timeline_limit, message_prefix)
+
+        # `self.assertEqual(...)` would also catch differences but the output is
+        # almost impossible to read because of the way it truncates the output.
+        # The order doesn't actually matter, so use `assertCountEqual(...)` instead.
+ self.assertCountEqual(
+ actual.required_state_map, expected.required_state_map, message_prefix
+ )
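+        # Check the state keys for each event type individually so that a failure
+        # points at the specific event type that differs.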
+ for event_type, expected_state_keys in expected.required_state_map.items():
+ self.assertCountEqual(
+ actual.required_state_map[event_type],
+ expected_state_keys,
+ f"{message_prefix}: Mismatch for {event_type}",
+ )
+
+ @parameterized.expand(
+ [
+ (
+ "from_list_config",
+ """
+ Test that we can convert a `SlidingSyncConfig.SlidingSyncList` to a
+ `RoomSyncConfig`.
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.Member, "@bar"),
+ (EventTypes.Member, "@baz"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Name: {""},
+ EventTypes.Member: {
+ "@foo",
+ "@bar",
+ "@baz",
+ },
+ EventTypes.CanonicalAlias: {""},
+ },
+ ),
+ ),
+ (
+ "from_room_subscription",
+ """
+ Test that we can convert a `SlidingSyncConfig.RoomSubscription` to a
+ `RoomSyncConfig`.
+ """,
+ # Input
+ SlidingSyncConfig.RoomSubscription(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.Member, "@bar"),
+ (EventTypes.Member, "@baz"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Name: {""},
+ EventTypes.Member: {
+ "@foo",
+ "@bar",
+ "@baz",
+ },
+ EventTypes.CanonicalAlias: {""},
+ },
+ ),
+ ),
+ (
+ "wildcard",
+ """
+ Test that a wildcard (*) for both the `event_type` and `state_key` will override
+ all other values.
+
+                Note: MSC3575 describes different behavior from how we're handling things
+                here, but since it's not wrong to return more state than requested
+                (`required_state` is just the minimum requested), it doesn't matter if we
+                include state that the client wanted excluded. This complexity is also
+                under scrutiny, see
+ https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050
+
+ > One unique exception is when you request all state events via ["*", "*"]. When used,
+ > all state events are returned by default, and additional entries FILTER OUT the returned set
+ > of state events. These additional entries cannot use '*' themselves.
+ > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member
+ > event _except_ for @alice:example.com, and include every other state event.
+ > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not
+ > required as it would have been returned anyway.
+ >
+ > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575)
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (StateValues.WILDCARD, StateValues.WILDCARD),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD},
+ },
+ ),
+ ),
+ (
+ "wildcard_type",
+ """
+                Test that a wildcard (*) as an `event_type` will override all other values
+                for the same `state_key`.
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (StateValues.WILDCARD, ""),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {""},
+ EventTypes.Member: {"@foo"},
+ },
+ ),
+ ),
+ (
+ "multiple_wildcard_type",
+ """
+                Test that multiple wildcard (*) entries as the `event_type` will override
+                all other values for the same `state_key`.
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (StateValues.WILDCARD, ""),
+ (EventTypes.Member, "@foo"),
+ (StateValues.WILDCARD, "@foo"),
+ ("org.matrix.personal_count", "@foo"),
+ (EventTypes.Member, "@bar"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {
+ "",
+ "@foo",
+ },
+ EventTypes.Member: {"@bar"},
+ },
+ ),
+ ),
+ (
+ "wildcard_state_key",
+ """
+ Test that a wildcard (*) as a `state_key` will override all other values for the
+ same `event_type`.
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.Member, StateValues.WILDCARD),
+ (EventTypes.Member, "@bar"),
+ (EventTypes.Member, StateValues.LAZY),
+ (EventTypes.Member, "@baz"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Name: {""},
+ EventTypes.Member: {
+ StateValues.WILDCARD,
+ },
+ EventTypes.CanonicalAlias: {""},
+ },
+ ),
+ ),
+ (
+ "wildcard_merge",
+ """
+                Test that a wildcard (*) entry for the `event_type` and another one for
+                the `state_key` play together.
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (StateValues.WILDCARD, ""),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.Member, StateValues.WILDCARD),
+ (EventTypes.Member, "@bar"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {""},
+ EventTypes.Member: {StateValues.WILDCARD},
+ },
+ ),
+ ),
+ (
+ "wildcard_merge2",
+ """
+                Test that an all-wildcard ("*", "*") entry will override any other
+ values (including other wildcards).
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (StateValues.WILDCARD, ""),
+ (EventTypes.Member, StateValues.WILDCARD),
+ (EventTypes.Member, "@foo"),
+ # One of these should take precedence over everything else
+ (StateValues.WILDCARD, StateValues.WILDCARD),
+ (StateValues.WILDCARD, StateValues.WILDCARD),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD},
+ },
+ ),
+ ),
+ (
+ "lazy_members",
+ """
+ `$LAZY` room members should just be another additional key next to other
+ explicit keys. We will unroll the special `$LAZY` meaning later.
+ """,
+ # Input
+ SlidingSyncConfig.SlidingSyncList(
+ timeline_limit=10,
+ required_state=[
+ (EventTypes.Name, ""),
+ (EventTypes.Member, "@foo"),
+ (EventTypes.Member, "@bar"),
+ (EventTypes.Member, StateValues.LAZY),
+ (EventTypes.Member, "@baz"),
+ (EventTypes.CanonicalAlias, ""),
+ ],
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Name: {""},
+ EventTypes.Member: {
+ "@foo",
+ "@bar",
+ StateValues.LAZY,
+ "@baz",
+ },
+ EventTypes.CanonicalAlias: {""},
+ },
+ ),
+ ),
+ ]
+ )
+ def test_from_room_config(
+ self,
+ _test_label: str,
+ _test_description: str,
+ room_params: SlidingSyncConfig.CommonRoomParameters,
+ expected_room_sync_config: RoomSyncConfig,
+ ) -> None:
+ """
+        Test that `RoomSyncConfig.from_room_config(room_params)` results in the
+        `expected_room_sync_config`.
+ """
+ room_sync_config = RoomSyncConfig.from_room_config(room_params)
+
+ self._assert_room_config_equal(
+ room_sync_config,
+ expected_room_sync_config,
+ )
+
+ @parameterized.expand(
+ [
+ (
+ "no_direct_overlap",
+ # A
+ RoomSyncConfig(
+ timeline_limit=9,
+ required_state_map={
+ EventTypes.Name: {""},
+ EventTypes.Member: {
+ "@foo",
+ "@bar",
+ },
+ },
+ ),
+ # B
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@baz",
+ },
+ EventTypes.CanonicalAlias: {""},
+ },
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Name: {""},
+ EventTypes.Member: {
+ "@foo",
+ "@bar",
+ StateValues.LAZY,
+ "@baz",
+ },
+ EventTypes.CanonicalAlias: {""},
+ },
+ ),
+ ),
+ (
+ "wildcard_overlap",
+ # A
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD},
+ },
+ ),
+ # B
+ RoomSyncConfig(
+ timeline_limit=9,
+ required_state_map={
+ EventTypes.Dummy: {StateValues.WILDCARD},
+ StateValues.WILDCARD: {"@bar"},
+ EventTypes.Member: {"@foo"},
+ },
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD},
+ },
+ ),
+ ),
+ (
+ "state_type_wildcard_overlap",
+ # A
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Dummy: {"dummy"},
+ StateValues.WILDCARD: {
+ "",
+ "@foo",
+ },
+ EventTypes.Member: {"@bar"},
+ },
+ ),
+ # B
+ RoomSyncConfig(
+ timeline_limit=9,
+ required_state_map={
+ EventTypes.Dummy: {"dummy2"},
+ StateValues.WILDCARD: {
+ "",
+ "@bar",
+ },
+ EventTypes.Member: {"@foo"},
+ },
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Dummy: {
+ "dummy",
+ "dummy2",
+ },
+ StateValues.WILDCARD: {
+ "",
+ "@foo",
+ "@bar",
+ },
+ },
+ ),
+ ),
+ (
+ "state_key_wildcard_overlap",
+ # A
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Dummy: {"dummy"},
+ EventTypes.Member: {StateValues.WILDCARD},
+ "org.matrix.flowers": {StateValues.WILDCARD},
+ },
+ ),
+ # B
+ RoomSyncConfig(
+ timeline_limit=9,
+ required_state_map={
+ EventTypes.Dummy: {StateValues.WILDCARD},
+ EventTypes.Member: {StateValues.WILDCARD},
+ "org.matrix.flowers": {"tulips"},
+ },
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Dummy: {StateValues.WILDCARD},
+ EventTypes.Member: {StateValues.WILDCARD},
+ "org.matrix.flowers": {StateValues.WILDCARD},
+ },
+ ),
+ ),
+ (
+ "state_type_and_state_key_wildcard_merge",
+ # A
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Dummy: {"dummy"},
+ StateValues.WILDCARD: {
+ "",
+ "@foo",
+ },
+ EventTypes.Member: {"@bar"},
+ },
+ ),
+ # B
+ RoomSyncConfig(
+ timeline_limit=9,
+ required_state_map={
+ EventTypes.Dummy: {"dummy2"},
+ StateValues.WILDCARD: {""},
+ EventTypes.Member: {StateValues.WILDCARD},
+ },
+ ),
+ # Expected
+ RoomSyncConfig(
+ timeline_limit=10,
+ required_state_map={
+ EventTypes.Dummy: {
+ "dummy",
+ "dummy2",
+ },
+ StateValues.WILDCARD: {
+ "",
+ "@foo",
+ },
+ EventTypes.Member: {StateValues.WILDCARD},
+ },
+ ),
+ ),
+ ]
+ )
+ def test_combine_room_sync_config(
+ self,
+ _test_label: str,
+ a: RoomSyncConfig,
+ b: RoomSyncConfig,
+ expected: RoomSyncConfig,
+ ) -> None:
+ """
+ Combine A into B and B into A to make sure we get the same result.
+ """
+ # Since we're mutating these in place, make a copy for each of our trials
+ room_sync_config_a = deepcopy(a)
+ room_sync_config_b = deepcopy(b)
+
+ # Combine B into A
+ room_sync_config_a.combine_room_sync_config(room_sync_config_b)
+
+ self._assert_room_config_equal(room_sync_config_a, expected, "B into A")
+
+ # Since we're mutating these in place, make a copy for each of our trials
+ room_sync_config_a = deepcopy(a)
+ room_sync_config_b = deepcopy(b)
+
+ # Combine A into B
+ room_sync_config_b.combine_room_sync_config(room_sync_config_a)
+
+ self._assert_room_config_equal(room_sync_config_b, expected, "A into B")
+
+
class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it returns
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 966c622e14..cb2888409e 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -20,7 +20,7 @@
#
import json
import logging
-from typing import Dict, List
+from typing import AbstractSet, Any, Dict, Iterable, List, Optional
from parameterized import parameterized, parameterized_class
@@ -32,9 +32,12 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
HistoryVisibility,
+ Membership,
ReceiptTypes,
RelationTypes,
)
+from synapse.events import EventBase
+from synapse.handlers.sliding_sync import StateValues
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
@@ -45,6 +48,7 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
+from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__)
@@ -1237,6 +1241,94 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
)
self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources()
+ self.storage_controllers = hs.get_storage_controllers()
+
+ def _assertRequiredStateIncludes(
+ self,
+ actual_required_state: Any,
+ expected_state_events: Iterable[EventBase],
+ exact: bool = False,
+ ) -> None:
+ """
+ Wrapper around `_assertIncludes` to give slightly better looking diff error
+ messages that include some context "$event_id (type, state_key)".
+
+ Args:
+ actual_required_state: The "required_state" of a room from a Sliding Sync
+ request response.
+ expected_state_events: The expected state events to be included in the
+ `actual_required_state`.
+ exact: Whether the actual state should be exactly equal to the expected
+ state (no extras).
+ """
+
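+        # `actual_required_state` comes from an untyped JSON response body, so sanity
+        # check its shape before formatting the entries for comparison.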
+ assert isinstance(actual_required_state, list)
+ for event in actual_required_state:
+ assert isinstance(event, dict)
+
+ self._assertIncludes(
+ {
+ f'{event["event_id"]} ("{event["type"]}", "{event["state_key"]}")'
+ for event in actual_required_state
+ },
+ {
+ f'{event.event_id} ("{event.type}", "{event.state_key}")'
+ for event in expected_state_events
+ },
+ exact=exact,
+ # Message to help understand the diff in context
+ message=str(actual_required_state),
+ )
+
+ def _assertIncludes(
+ self,
+ actual_items: AbstractSet[str],
+ expected_items: AbstractSet[str],
+ exact: bool = False,
+ message: Optional[str] = None,
+ ) -> None:
+ """
+ Assert that all of the `expected_items` are included in the `actual_items`.
+
+        This assert could also be called `assertContains` or `assertItemsInSet`.
+
+ Args:
+ actual_items: The container
+ expected_items: The items to check for in the container
+            exact: Whether the actual items should be exactly equal to the expected
+                items (no extras).
+ message: Optional message to include in the failure message.
+ """
+ # Check that each set has the same items
+ if exact and actual_items == expected_items:
+ return
+ # Check for a superset
+ elif not exact and actual_items >= expected_items:
+ return
+
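+        # Build a readable diff: '?' marks an expected item that is missing from the
+        # actual set, '+' marks an actual item that matches an expected item.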
+ expected_lines: List[str] = []
+ for expected_item in expected_items:
+ is_expected_in_actual = expected_item in actual_items
+ expected_lines.append(
+ "{} {}".format(" " if is_expected_in_actual else "?", expected_item)
+ )
+
+ actual_lines: List[str] = []
+ for actual_item in actual_items:
+ is_actual_in_expected = actual_item in expected_items
+ actual_lines.append(
+ "{} {}".format("+" if is_actual_in_expected else " ", actual_item)
+ )
+
+ newline = "\n"
+ expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}"
+ actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}"
+ first_message = (
+ "Items must match exactly" if exact else "Some expected items are missing."
+ )
+ diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
+
+ self.fail(f"{diff_message}\n{message}")
def _add_new_dm_to_global_account_data(
self, source_user_id: str, target_user_id: str, target_room_id: str
@@ -2091,6 +2183,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1].get("prev_batch"),
channel.json_body["rooms"][room_id1],
)
+ # `required_state` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("required_state"),
+ channel.json_body["rooms"][room_id1],
+ )
# We should have some `stripped_state` so the potential joiner can identify the
# room (we don't care about the order).
self.assertCountEqual(
@@ -2200,6 +2297,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1].get("prev_batch"),
channel.json_body["rooms"][room_id1],
)
+ # `required_state` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("required_state"),
+ channel.json_body["rooms"][room_id1],
+ )
# We should have some `stripped_state` so the potential joiner can identify the
# room (we don't care about the order).
self.assertCountEqual(
@@ -2321,6 +2423,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1].get("prev_batch"),
channel.json_body["rooms"][room_id1],
)
+ # `required_state` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("required_state"),
+ channel.json_body["rooms"][room_id1],
+ )
# We should have some `stripped_state` so the potential joiner can identify the
# room (we don't care about the order).
self.assertCountEqual(
@@ -2448,6 +2555,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1].get("prev_batch"),
channel.json_body["rooms"][room_id1],
)
+ # `required_state` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("required_state"),
+ channel.json_body["rooms"][room_id1],
+ )
# We should have some `stripped_state` so the potential joiner can identify the
# room (we don't care about the order).
self.assertCountEqual(
@@ -2681,3 +2793,602 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
False,
channel.json_body["rooms"][room_id1],
)
+
+ def test_rooms_no_required_state(self) -> None:
+ """
+ Empty `rooms.required_state` should not return any state events in the room
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ # Empty `required_state`
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # No `required_state` in response
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("required_state"),
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_required_state_initial_sync(self) -> None:
+ """
+ Test `rooms.required_state` returns requested state events in the room during an
+ initial sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.RoomHistoryVisibility, ""],
+ # This one doesn't exist in the room
+ [EventTypes.Tombstone, ""],
+ ],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ state_map[(EventTypes.RoomHistoryVisibility, "")],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_incremental_sync(self) -> None:
+ """
+ Test `rooms.required_state` returns requested state events in the room during an
+ incremental sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ after_room_token = self.event_sources.get_current_token()
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(after_room_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.RoomHistoryVisibility, ""],
+ # This one doesn't exist in the room
+ [EventTypes.Tombstone, ""],
+ ],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+        # The returned state doesn't change from initial to incremental sync. In the
+        # future, we will only return updates, and only if we've already sent the room
+        # down the connection before.
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ state_map[(EventTypes.RoomHistoryVisibility, "")],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_wildcard(self) -> None:
+ """
+ Test `rooms.required_state` returns all state events when using wildcard `["*", "*"]`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="namespaced",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+
+ # Make the Sliding Sync request with wildcards for the `event_type` and `state_key`
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [StateValues.WILDCARD, StateValues.WILDCARD],
+ ],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ # We should see all the state events in the room
+ state_map.values(),
+ exact=True,
+ )
+
+ def test_rooms_required_state_wildcard_event_type(self) -> None:
+ """
+        Test `rooms.required_state` returns relevant state events when using a wildcard
+        for the `event_type`, e.g. `["*", "foobarbaz"]`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key=user2_id,
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+
+ # Make the Sliding Sync request with wildcards for the `event_type`
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [StateValues.WILDCARD, user2_id],
+ ],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+        # We should see at least any state event with `user2_id` as the `state_key`
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[("org.matrix.foo_state", user2_id)],
+ },
+ # Ideally, this would be exact but we're currently returning all state
+ # events when the `event_type` is a wildcard.
+ exact=False,
+ )
+
+ def test_rooms_required_state_wildcard_state_key(self) -> None:
+ """
+        Test `rooms.required_state` returns relevant state events when using a wildcard
+        for the `state_key`, e.g. `["foobarbaz", "*"]`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request with wildcards for the `state_key`
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Member, StateValues.WILDCARD],
+ ],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user1_id)],
+ state_map[(EventTypes.Member, user2_id)],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_lazy_loading_room_members(self) -> None:
+ """
+ Test `rooms.required_state` returns people relevant to the timeline when
+        lazy-loading room members, `["m.room.member", "$LAZY"]`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+
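+        # Send a few timeline events so that only user2 and user3 show up as senders
+        # in the limited timeline window below.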
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user3_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request with lazy loading for the room members
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user3 sent events in the 3 events we see in the `timeline`
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[(EventTypes.Member, user3_id)],
+ },
+ exact=True,
+ )
+
+ @parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
+ def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None:
+ """
+        Test that `rooms.required_state` does not return state past a leave/ban event.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ from_token = self.event_sources.get_current_token()
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+
+ if stop_membership == Membership.LEAVE:
+ # User 1 leaves
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ elif stop_membership == Membership.BAN:
+ # User 1 is banned
+ self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+
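+        # Snapshot the current state now, before we make any more state changes, since
+        # `required_state` should reflect the state at the point user1 left/was banned.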
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+        # Change the state after user1 has left/been banned
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "qux"},
+ tok=user2_tok,
+ )
+ self.helper.leave(room_id1, user3_id, tok=user3_tok)
+
+        # Make the incremental Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+        # We should only see the state at the point of the leave/ban (the `state_map`
+        # snapshot from above), not any of the state changes that happened afterwards.
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ state_map[(EventTypes.Member, user1_id)],
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[(EventTypes.Member, user3_id)],
+ state_map[("org.matrix.foo_state", "")],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_combine_superset(self) -> None:
+ """
+ Test `rooms.required_state` is combined across lists and room subscriptions.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+
+        # Make the Sliding Sync request with two lists so that their `required_state`
+        # filters get combined
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, user1_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "bar-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Member, StateValues.WILDCARD],
+ ["org.matrix.foo_state", ""],
+ ],
+ "timeline_limit": 0,
+ },
+ }
+ # TODO: Room subscription should also combine with the `required_state`
+ # "room_subscriptions": {
+ # room_id1: {
+ # "required_state": [
+ # ["org.matrix.bar_state", ""]
+ # ],
+ # "timeline_limit": 0,
+ # }
+ # }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ state_map[(EventTypes.Member, user1_id)],
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[("org.matrix.foo_state", "")],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_partial_state(self) -> None:
+ """
+        Test that partially-stated rooms are excluded unless `rooms.required_state` is
+ lazy-loading room members.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ _join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
+
+ # Mark room2 as partial state
+ self.get_success(
+ mark_event_as_partial_state(self.hs, join_response2["event_id"], room_id2)
+ )
+
+ # Make the Sliding Sync request (NOT lazy-loading room members)
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 0,
+ },
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Make sure the list includes room1 but room2 is excluded because it's still
+ # partially-stated
+ self.assertListEqual(
+ list(channel.json_body["lists"]["foo-list"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 1],
+ "room_ids": [room_id1],
+ }
+ ],
+ channel.json_body["lists"]["foo-list"],
+ )
+
+ # Make the Sliding Sync request (with lazy-loading room members)
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ # Lazy-load room members
+ [EventTypes.Member, StateValues.LAZY],
+ ],
+ "timeline_limit": 0,
+ },
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # The list should include both rooms now because we're lazy-loading room members
+ self.assertListEqual(
+ list(channel.json_body["lists"]["foo-list"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 1],
+ "room_ids": [room_id2, room_id1],
+ }
+ ],
+ channel.json_body["lists"]["foo-list"],
+ )
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index fd03c23b89..35b3245708 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -125,13 +125,15 @@ async def mark_event_as_partial_state(
in this table).
"""
store = hs.get_datastores().main
- await store.db_pool.simple_upsert(
- table="partial_state_rooms",
- keyvalues={"room_id": room_id},
- values={},
- insertion_values={"room_id": room_id},
+ # Use the store helper to insert into the database so the caches are busted
+ await store.store_partial_state_room(
+ room_id=room_id,
+ servers={hs.hostname},
+ device_lists_stream_id=0,
+ joined_via=hs.hostname,
)
+ # FIXME: Bust the cache
await store.db_pool.simple_insert(
table="partial_state_events",
values={
|