diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py
index 8c12cea8eb..39dba4ff98 100644
--- a/synapse/handlers/sliding_sync/__init__.py
+++ b/synapse/handlers/sliding_sync/__init__.py
@@ -14,7 +14,7 @@
import logging
from itertools import chain
-from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple
+from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple
from prometheus_client import Histogram
from typing_extensions import assert_never
@@ -522,6 +522,8 @@ class SlidingSyncHandler:
state_reset_out_of_room = True
+ prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
+
# Determine whether we should limit the timeline to the token range.
#
# We should return historical messages (before token range) in the
@@ -550,7 +552,6 @@ class SlidingSyncHandler:
# or `limited` mean for clients that interpret them correctly. In future this
# behavior is almost certainly going to change.
#
- # TODO: Also handle changes to `required_state`
from_bound = None
initial = True
ignore_timeline_bound = False
@@ -571,7 +572,6 @@ class SlidingSyncHandler:
log_kv({"sliding_sync.room_status": room_status})
- prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
if prev_room_sync_config is not None:
# Check if the timeline limit has increased, if so ignore the
# timeline bound and record the change (see "XXX: Odd behavior"
@@ -582,8 +582,6 @@ class SlidingSyncHandler:
):
ignore_timeline_bound = True
- # TODO: Check for changes in `required_state``
-
log_kv(
{
"sliding_sync.from_bound": from_bound,
@@ -997,6 +995,10 @@ class SlidingSyncHandler:
include_others=required_state_filter.include_others,
)
+ # The required state map to store in the room sync config, if it has
+ # changed.
+ changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None
+
# We can return all of the state that was requested if this was the first
# time we've sent the room down this connection.
room_state: StateMap[EventBase] = {}
@@ -1010,6 +1012,29 @@ class SlidingSyncHandler:
else:
assert from_bound is not None
+ if prev_room_sync_config is not None:
+ # Check if there are any changes to the required state config
+ # that we need to handle.
+ changed_required_state_map, added_state_filter = (
+ _required_state_changes(
+ user.to_string(),
+ previous_room_config=prev_room_sync_config,
+ room_sync_config=room_sync_config,
+ state_deltas=room_state_delta_id_map,
+ )
+ )
+
+ if added_state_filter:
+ # Some state entries got added, so we pull out the current
+ # state for them. If we don't do this we'd only send down new deltas.
+ state_ids = await self.get_current_state_ids_at(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ state_filter=added_state_filter,
+ to_token=to_token,
+ )
+ room_state_delta_id_map.update(state_ids)
+
events = await self.store.get_events(
state_filter.filter_state(room_state_delta_id_map).values()
)
@@ -1108,10 +1133,13 @@ class SlidingSyncHandler:
# sensible order again.
bump_stamp = 0
- unstable_expanded_timeline = False
- prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
+ room_sync_required_state_map_to_persist = room_sync_config.required_state_map
+ if changed_required_state_map:
+ room_sync_required_state_map_to_persist = changed_required_state_map
+
# Record the `room_sync_config` if we're `ignore_timeline_bound` (which means
# that the `timeline_limit` has increased)
+ unstable_expanded_timeline = False
if ignore_timeline_bound:
# FIXME: We signal the fact that we're sending down more events to
# the client by setting `unstable_expanded_timeline` to true (see
@@ -1120,7 +1148,7 @@ class SlidingSyncHandler:
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
- required_state_map=room_sync_config.required_state_map,
+ required_state_map=room_sync_required_state_map_to_persist,
)
elif prev_room_sync_config is not None:
# If the result is `limited` then we need to record that the
@@ -1149,10 +1177,14 @@ class SlidingSyncHandler:
):
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
- required_state_map=room_sync_config.required_state_map,
+ required_state_map=room_sync_required_state_map_to_persist,
)
- # TODO: Record changes in required_state.
+ elif changed_required_state_map is not None:
+ new_connection_state.room_configs[room_id] = RoomSyncConfig(
+ timeline_limit=room_sync_config.timeline_limit,
+ required_state_map=room_sync_required_state_map_to_persist,
+ )
else:
new_connection_state.room_configs[room_id] = room_sync_config
@@ -1285,3 +1317,185 @@ class SlidingSyncHandler:
return new_bump_event_pos.stream
return None
+
+
+def _required_state_changes(
+ user_id: str,
+ *,
+ previous_room_config: "RoomSyncConfig",
+ room_sync_config: RoomSyncConfig,
+ state_deltas: StateMap[str],
+) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]:
+ """Calculates the changes between the required state room config from the
+ previous requests compared with the current request.
+
+ This does two things. First, it calculates if we need to update the room
+ config due to changes to required state. Secondly, it works out which state
+ entries we need to pull from current state and return due to the state entry
+ now appearing in the required state when it previously wasn't (on top of the
+ state deltas).
+
+ This function tries to ensure to handle the case where a state entry is
+ added, removed and then added again to the required state. In that case we
+ only want to re-send that entry down sync if it has changed.
+
+ Returns:
+ A 2-tuple of updated required state config (or None if there is no update)
+ and the state filter to use to fetch extra current state that we need to
+ return.
+ """
+
+ prev_required_state_map = previous_room_config.required_state_map
+ request_required_state_map = room_sync_config.required_state_map
+
+ if prev_required_state_map == request_required_state_map:
+ # There has been no change. Return immediately.
+ return None, StateFilter.none()
+
+ prev_wildcard = prev_required_state_map.get(StateValues.WILDCARD, set())
+ request_wildcard = request_required_state_map.get(StateValues.WILDCARD, set())
+
+ # If we were previously fetching everything ("*", "*"), always update the effective
+ # room required state config to match the request. And since we we're previously
+ # already fetching everything, we don't have to fetch anything now that they've
+ # narrowed.
+ if StateValues.WILDCARD in prev_wildcard:
+ return request_required_state_map, StateFilter.none()
+
+ # If a event type wildcard has been added or removed we don't try and do
+ # anything fancy, and instead always update the effective room required
+ # state config to match the request.
+ if request_wildcard - prev_wildcard:
+ # Some keys were added, so we need to fetch everything
+ return request_required_state_map, StateFilter.all()
+ if prev_wildcard - request_wildcard:
+ # Keys were only removed, so we don't have to fetch everything.
+ return request_required_state_map, StateFilter.none()
+
+ # Contains updates to the required state map compared with the previous room
+ # config. This has the same format as `RoomSyncConfig.required_state`
+ changes: Dict[str, AbstractSet[str]] = {}
+
+ # The set of types/state keys that we need to fetch and return to the
+ # client. Passed to `StateFilter.from_types(...)`
+ added: List[Tuple[str, Optional[str]]] = []
+
+ # First we calculate what, if anything, has been *added*.
+ for event_type in (
+ prev_required_state_map.keys() | request_required_state_map.keys()
+ ):
+ old_state_keys = prev_required_state_map.get(event_type, set())
+ request_state_keys = request_required_state_map.get(event_type, set())
+
+ if old_state_keys == request_state_keys:
+ # No change to this type
+ continue
+
+ if not request_state_keys - old_state_keys:
+ # Nothing *added*, so we skip. Removals happen below.
+ continue
+
+ # Always update changes to include the newly added keys
+ changes[event_type] = request_state_keys
+
+ if StateValues.WILDCARD in old_state_keys:
+ # We were previously fetching everything for this type, so we don't need to
+ # fetch anything new.
+ continue
+
+ # Record the new state keys to fetch for this type.
+ if StateValues.WILDCARD in request_state_keys:
+ # If we have added a wildcard then we always just fetch everything.
+ added.append((event_type, None))
+ else:
+ for state_key in request_state_keys - old_state_keys:
+ if state_key == StateValues.ME:
+ added.append((event_type, user_id))
+ elif state_key == StateValues.LAZY:
+ # We handle lazy loading separately (outside this function),
+ # so don't need to explicitly add anything here.
+ #
+ # LAZY values should also be ignore for event types that are
+ # not membership.
+ pass
+ else:
+ added.append((event_type, state_key))
+
+ added_state_filter = StateFilter.from_types(added)
+
+ # Convert the list of state deltas to map from type to state_keys that have
+ # changed.
+ changed_types_to_state_keys: Dict[str, Set[str]] = {}
+ for event_type, state_key in state_deltas:
+ changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)
+
+ # Figure out what changes we need to apply to the effective required state
+ # config.
+ for event_type, changed_state_keys in changed_types_to_state_keys.items():
+ old_state_keys = prev_required_state_map.get(event_type, set())
+ request_state_keys = request_required_state_map.get(event_type, set())
+
+ if old_state_keys == request_state_keys:
+ # No change.
+ continue
+
+ if request_state_keys - old_state_keys:
+ # We've expanded the set of state keys, so we just clobber the
+ # current set with the new set.
+ #
+ # We could also ensure that we keep entries where the state hasn't
+ # changed, but are no longer in the requested required state, but
+ # that's a sufficient edge case that we can ignore (as its only a
+ # performance optimization).
+ changes[event_type] = request_state_keys
+ continue
+
+ old_state_key_wildcard = StateValues.WILDCARD in old_state_keys
+ request_state_key_wildcard = StateValues.WILDCARD in request_state_keys
+
+ if old_state_key_wildcard != request_state_key_wildcard:
+ # If a state_key wildcard has been added or removed, we always update the
+ # effective room required state config to match the request.
+ changes[event_type] = request_state_keys
+ continue
+
+ if event_type == EventTypes.Member:
+ old_state_key_lazy = StateValues.LAZY in old_state_keys
+ request_state_key_lazy = StateValues.LAZY in request_state_keys
+
+ if old_state_key_lazy != request_state_key_lazy:
+ # If a "$LAZY" has been added or removed we always update the effective room
+ # required state config to match the request.
+ changes[event_type] = request_state_keys
+ continue
+
+ # Handle "$ME" values by adding "$ME" if the state key matches the user
+ # ID.
+ if user_id in changed_state_keys:
+ changed_state_keys.add(StateValues.ME)
+
+ # At this point there are no wildcards and no additions to the set of
+ # state keys requested, only deletions.
+ #
+ # We only remove state keys from the effective state if they've been
+ # removed from the request *and* the state has changed. This ensures
+ # that if a client removes and then re-adds a state key, we only send
+ # down the associated current state event if its changed (rather than
+ # sending down the same event twice).
+ invalidated = (old_state_keys - request_state_keys) & changed_state_keys
+ if invalidated:
+ changes[event_type] = old_state_keys - invalidated
+
+ if changes:
+ # Update the required state config based on the changes.
+ new_required_state_map = dict(prev_required_state_map)
+ for event_type, state_keys in changes.items():
+ if state_keys:
+ new_required_state_map[event_type] = state_keys
+ else:
+ # Remove entries with empty state keys.
+ new_required_state_map.pop(event_type, None)
+
+ return new_required_state_map, added_state_filter
+ else:
+ return None, added_state_filter
diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py
index f2df37fec1..7b357c1ffe 100644
--- a/synapse/storage/databases/main/sliding_sync.py
+++ b/synapse/storage/databases/main/sliding_sync.py
@@ -386,8 +386,8 @@ class SlidingSyncStore(SQLBaseStore):
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
for row in rows:
state = required_state_map[row[0]] = {}
- for event_type, state_keys in db_to_json(row[1]):
- state[event_type] = set(state_keys)
+ for event_type, state_key in db_to_json(row[1]):
+ state.setdefault(event_type, set()).add(state_key)
# Get all the room configs, looking up the required state from the map
# above.
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 1141c4b5c1..67d1c3fe97 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -616,6 +616,13 @@ class StateFilter:
return False
+ def __bool__(self) -> bool:
+ """Returns true if this state filter will match any state, or false if
+ this is the empty filter"""
+ if self.include_others:
+ return True
+ return bool(self.types)
+
_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|