diff --git a/changelog.d/17915.bugfix b/changelog.d/17915.bugfix
new file mode 100644
index 0000000000..a5d82e486d
--- /dev/null
+++ b/changelog.d/17915.bugfix
@@ -0,0 +1 @@
+Fix experimental support for [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222) where we would return the full state on incremental syncs when using lazy loaded members and there were no new events in the timeline.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 204965afee..df3010ecf6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -196,7 +196,9 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view
members of this room.
"""
- state_filter = state_filter or StateFilter.all()
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
user_id = requester.user.to_string()
if at_token:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index df9a088063..350c3fa09a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1520,7 +1520,7 @@ class SyncHandler:
if sync_config.use_state_after:
delta_state_ids: MutableStateMap[str] = {}
- if members_to_fetch is not None:
+ if members_to_fetch:
# We're lazy-loading, so the client might need some more member
# events to understand the events in this timeline. So we always
# fish out all the member events corresponding to the timeline
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index b50eb8868e..f28f5d7e03 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -234,8 +234,11 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -244,7 +247,7 @@ class StateStorageController:
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
+ groups, state_filter
)
state_event_map = await self.stores.main.get_events(
@@ -292,10 +295,11 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- if (
- await_full_state
- and state_filter
- and not state_filter.must_await_full_state(self._is_mine_id)
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
+ if await_full_state and not state_filter.must_await_full_state(
+ self._is_mine_id
):
# Full state is not required if the state filter is restrictive enough.
await_full_state = False
@@ -306,7 +310,7 @@ class StateStorageController:
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
+ groups, state_filter
)
event_to_state = {
@@ -335,9 +339,10 @@ class StateStorageController:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
- state_map = await self.get_state_for_events(
- [event_id], state_filter or StateFilter.all()
- )
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
+ state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
@trace
@@ -365,9 +370,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
state_map = await self.get_state_ids_for_events(
[event_id],
- state_filter or StateFilter.all(),
+ state_filter,
await_full_state=await_full_state,
)
return state_map[event_id]
@@ -388,9 +396,12 @@ class StateStorageController:
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
"""
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
state_ids = await self.get_state_ids_for_event(
event_id,
- state_filter=state_filter or StateFilter.all(),
+ state_filter=state_filter,
await_full_state=await_full_state,
)
@@ -426,6 +437,9 @@ class StateStorageController:
at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`.
"""
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
@@ -442,7 +456,7 @@ class StateStorageController:
if last_event_id:
state = await self.get_state_after_event(
last_event_id,
- state_filter=state_filter or StateFilter.all(),
+ state_filter=state_filter,
await_full_state=await_full_state,
)
@@ -500,9 +514,10 @@ class StateStorageController:
Returns:
Dict of state group to state map.
"""
- return await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
+ return await self.stores.state._get_state_for_groups(groups, state_filter)
@trace
@tag_args
@@ -583,12 +598,13 @@ class StateStorageController:
Returns:
The current state of the room.
"""
- if await_full_state and (
- not state_filter or state_filter.must_await_full_state(self._is_mine_id)
- ):
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
+ if await_full_state and state_filter.must_await_full_state(self._is_mine_id):
await self._partial_state_room_tracker.await_full_state(room_id)
- if state_filter and not state_filter.is_full():
+ if state_filter is not None and not state_filter.is_full():
return await self.stores.main.get_partial_filtered_current_state_ids(
room_id, state_filter
)
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 42b3638e1c..788f7d1e32 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -572,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Map from type/state_key to event ID.
"""
+ if state_filter is None:
+ state_filter = StateFilter.all()
- where_clause, where_args = (
- state_filter or StateFilter.all()
- ).make_sql_filter_clause()
+ where_clause, where_args = (state_filter).make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
@@ -584,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
- results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
+ results = StateMapWrapper(state_filter=state_filter)
sql = """
SELECT type, state_key, event_id FROM current_state_events
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index ea7d8199a7..f7824cba0f 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -112,8 +112,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
Returns:
Map from state_group to a StateMap at that point.
"""
-
- state_filter = state_filter or StateFilter.all()
+ if state_filter is None:
+ state_filter = StateFilter.all()
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 875dba3349..f7a59c8992 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -284,7 +284,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
- state_filter = state_filter or StateFilter.all()
+ if state_filter is None:
+ state_filter = StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split()
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 67d1c3fe97..e641215f18 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -68,15 +68,23 @@ class StateFilter:
include_others: bool = False
def __attrs_post_init__(self) -> None:
- # If `include_others` is set we canonicalise the filter by removing
- # wildcards from the types dictionary
if self.include_others:
+ # If `include_others` is set we canonicalise the filter by removing
+ # wildcards from the types dictionary
+
# this is needed to work around the fact that StateFilter is frozen
object.__setattr__(
self,
"types",
immutabledict({k: v for k, v in self.types.items() if v is not None}),
)
+ else:
+ # Otherwise we remove entries where the value is the empty set.
+ object.__setattr__(
+ self,
+ "types",
+ immutabledict({k: v for k, v in self.types.items() if v is None or v}),
+ )
@staticmethod
def all() -> "StateFilter":
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 1960d2f0e1..9dd0e98971 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -1262,3 +1262,35 @@ class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase):
)
)
self.assertEqual(state[("m.test_event", "")], second_state["event_id"])
+
+ def test_incremental_sync_lazy_loaded_no_timeline(self) -> None:
+ """Test that lazy-loading with an empty timeline doesn't return the full
+ state.
+
+ There was a bug where an empty state filter would cause the DB to return
+ the full state, rather than an empty set.
+ """
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ since_token = self.hs.get_event_sources().get_current_token()
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
+ members_to_fetch=set(),
+ timeline_state={},
+ )
+ )
+
+ self.assertEqual(state, {})
|