summary refs log tree commit diff
path: root/synapse/handlers/sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--synapse/handlers/sync.py97
1 files changed, 43 insertions, 54 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 351892a94f..09739f2862 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.roommember import MemberSummary
+from synapse.storage.state import StateFilter
 from synapse.types import RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -469,25 +470,20 @@ class SyncHandler(object):
         ))
 
     @defer.inlineCallbacks
-    def get_state_after_event(self, event, types=None, filtered_types=None):
+    def get_state_after_event(self, event, state_filter=StateFilter.all()):
         """
         Get the room state after the given event
 
         Args:
             event(synapse.events.EventBase): event of interest
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
         state_ids = yield self.store.get_state_ids_for_event(
-            event.event_id, types, filtered_types=filtered_types,
+            event.event_id, state_filter=state_filter,
         )
         if event.is_state():
             state_ids = state_ids.copy()
@@ -495,18 +491,14 @@ class SyncHandler(object):
         defer.returnValue(state_ids)
 
     @defer.inlineCallbacks
-    def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
+    def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
         """ Get the room state at a particular stream position
 
         Args:
             room_id(str): room for which to get state
             stream_position(StreamToken): point at which to get state
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A Deferred map from ((type, state_key)->Event)
@@ -522,7 +514,7 @@ class SyncHandler(object):
         if last_events:
             last_event = last_events[-1]
             state = yield self.get_state_after_event(
-                last_event, types, filtered_types=filtered_types,
+                last_event, state_filter=state_filter,
             )
 
         else:
@@ -563,10 +555,11 @@ class SyncHandler(object):
 
         last_event = last_events[-1]
         state_ids = yield self.store.get_state_ids_for_event(
-            last_event.event_id, [
+            last_event.event_id,
+            state_filter=StateFilter.from_types([
                 (EventTypes.Name, ''),
                 (EventTypes.CanonicalAlias, ''),
-            ]
+            ]),
         )
 
         # this is heavily cached, thus: fast.
@@ -717,8 +710,7 @@ class SyncHandler(object):
 
         with Measure(self.clock, "compute_state_delta"):
 
-            types = None
-            filtered_types = None
+            members_to_fetch = None
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -729,16 +721,21 @@ class SyncHandler(object):
                 # We only request state for the members needed to display the
                 # timeline:
 
-                types = [
-                    (EventTypes.Member, state_key)
-                    for state_key in set(
-                        event.sender  # FIXME: we also care about invite targets etc.
-                        for event in batch.events
-                    )
-                ]
+                members_to_fetch = set(
+                    event.sender  # FIXME: we also care about invite targets etc.
+                    for event in batch.events
+                )
+
+                if full_state:
+                    # always make sure we LL ourselves so we know we're in the room
+                    # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
+                    # We only need apply this on full state syncs given we disabled
+                    # LL for incr syncs in #3840.
+                    members_to_fetch.add(sync_config.user.to_string())
 
-                # only apply the filtering to room members
-                filtered_types = [EventTypes.Member]
+                state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+            else:
+                state_filter = StateFilter.all()
 
             timeline_state = {
                 (event.type, event.state_key): event.event_id
@@ -746,28 +743,19 @@ class SyncHandler(object):
             }
 
             if full_state:
-                if lazy_load_members:
-                    # always make sure we LL ourselves so we know we're in the room
-                    # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
-                    # We only need apply this on full state syncs given we disabled
-                    # LL for incr syncs in #3840.
-                    types.append((EventTypes.Member, sync_config.user.to_string()))
-
                 if batch:
                     current_state_ids = yield self.store.get_state_ids_for_event(
-                        batch.events[-1].event_id, types=types,
-                        filtered_types=filtered_types,
+                        batch.events[-1].event_id, state_filter=state_filter,
                     )
 
                     state_ids = yield self.store.get_state_ids_for_event(
-                        batch.events[0].event_id, types=types,
-                        filtered_types=filtered_types,
+                        batch.events[0].event_id, state_filter=state_filter,
                     )
 
                 else:
                     current_state_ids = yield self.get_state_at(
-                        room_id, stream_position=now_token, types=types,
-                        filtered_types=filtered_types,
+                        room_id, stream_position=now_token,
+                        state_filter=state_filter,
                     )
 
                     state_ids = current_state_ids
@@ -781,8 +769,7 @@ class SyncHandler(object):
                 )
             elif batch.limited:
                 state_at_timeline_start = yield self.store.get_state_ids_for_event(
-                    batch.events[0].event_id, types=types,
-                    filtered_types=filtered_types,
+                    batch.events[0].event_id, state_filter=state_filter,
                 )
 
                 # for now, we disable LL for gappy syncs - see
@@ -797,17 +784,15 @@ class SyncHandler(object):
                 # members to just be ones which were timeline senders, which then ensures
                 # all of the rest get included in the state block (if we need to know
                 # about them).
-                types = None
-                filtered_types = None
+                state_filter = StateFilter.all()
 
                 state_at_previous_sync = yield self.get_state_at(
-                    room_id, stream_position=since_token, types=types,
-                    filtered_types=filtered_types,
+                    room_id, stream_position=since_token,
+                    state_filter=state_filter,
                 )
 
                 current_state_ids = yield self.store.get_state_ids_for_event(
-                    batch.events[-1].event_id, types=types,
-                    filtered_types=filtered_types,
+                    batch.events[-1].event_id, state_filter=state_filter,
                 )
 
                 state_ids = _calculate_state(
@@ -821,7 +806,7 @@ class SyncHandler(object):
             else:
                 state_ids = {}
                 if lazy_load_members:
-                    if types and batch.events:
+                    if members_to_fetch and batch.events:
                         # We're returning an incremental sync, with no
                         # "gap" since the previous sync, so normally there would be
                         # no state to return.
@@ -831,8 +816,12 @@ class SyncHandler(object):
                         # timeline here, and then dedupe any redundant ones below.
 
                         state_ids = yield self.store.get_state_ids_for_event(
-                            batch.events[0].event_id, types=types,
-                            filtered_types=None,  # we only want members!
+                            batch.events[0].event_id,
+                            # we only want members!
+                            state_filter=StateFilter.from_types(
+                                (EventTypes.Member, member)
+                                for member in members_to_fetch
+                            ),
                         )
 
             if lazy_load_members and not include_redundant_members: