summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthew Hodgson <matthew@matrix.org>2018-07-24 12:39:40 +0100
committerMatthew Hodgson <matthew@matrix.org>2018-07-24 12:39:40 +0100
commitcd241d6bda01a761fbe1ca29727dacd918fb8975 (patch)
treeb430c373ea65462a02237e3ca81ed9f56375e389
parenthandle case where types is [] on postgres correctly (diff)
downloadsynapse-cd241d6bda01a761fbe1ca29727dacd918fb8975.tar.xz
incorporate more review
-rw-r--r--synapse/handlers/sync.py12
-rw-r--r--synapse/storage/state.py36
-rw-r--r--tests/storage/test_state.py9
3 files changed, 27 insertions, 30 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5689ad2f58..e5a2329d73 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1526,6 +1526,9 @@ def _calculate_state(
         previous (dict): state at the end of the previous sync (or empty dict
             if this is an initial sync)
         current (dict): state at the end of the timeline
+        lazy_load_members (bool): whether to return members from timeline_start
+            or not.  assumes that timeline_start has already been filtered to
+            include only the members the client needs to know about.
 
     Returns:
         dict
@@ -1545,9 +1548,12 @@ def _calculate_state(
     p_ids = set(e for e in previous.values())
     tc_ids = set(e for e in timeline_contains.values())
 
-    # track the membership events in the state as of the start of the timeline
-    # so we can add them back in to the state if we're lazyloading.  We don't
-    # add them into state if they're already contained in the timeline.
+    # If we are lazyloading room members, we explicitly add the membership events
+    # for the senders in the timeline into the state block returned by /sync,
+    # as we may not have sent them to the client before.  We find these membership
+    # events by filtering them out of timeline_start, which has already been filtered
+    # to only include membership events for the senders in the timeline.
+
     if lazy_load_members:
         ll_ids = set(
             e for t, e in timeline_start.iteritems()
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f99d3871e4..1413a6f910 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -185,7 +185,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types, filtered_types=None):
+    def _get_state_groups_from_groups(self, groups, types):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
@@ -194,9 +194,6 @@ class StateGroupWorkerStore(SQLBaseStore):
             types (Iterable[str, str|None]|None): list of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
                 state_keys for the `type`. If None, all types are returned.
-            filtered_types(Iterable[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
 
         Returns:
             dictionary state_group -> (dict of (type, state_key) -> event id)
@@ -207,14 +204,14 @@ class StateGroupWorkerStore(SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types, filtered_types,
+                self._get_state_groups_from_groups_txn, chunk, types,
             )
             results.update(res)
 
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, types=None, filtered_types=None,
+        self, txn, groups, types=None,
     ):
         results = {group: {} for group in groups}
 
@@ -266,17 +263,6 @@ class StateGroupWorkerStore(SQLBaseStore):
                     )
                     for etype, state_key in types
                 ]
-
-                if filtered_types is not None:
-                    # XXX: check whether this slows postgres down like a list of
-                    # ORs does too?
-                    unique_types = set(filtered_types)
-                    clause_to_args.append(
-                        (
-                            "AND type <> ? " * len(unique_types),
-                            list(unique_types)
-                        )
-                    )
             else:
                 # If types is None we fetch all the state, and so just use an
                 # empty where clause with no extra args.
@@ -306,13 +292,6 @@ class StateGroupWorkerStore(SQLBaseStore):
                         where_clauses.append("(type = ? AND state_key = ?)")
                         where_args.extend([typ[0], typ[1]])
 
-                if filtered_types is not None:
-                    unique_types = set(filtered_types)
-                    where_clauses.append(
-                        "(" + " AND ".join(["type <> ?"] * len(unique_types)) + ")"
-                    )
-                    where_args.extend(list(unique_types))
-
                 where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
@@ -643,13 +622,13 @@ class StateGroupWorkerStore(SQLBaseStore):
             # cache. Hence, if we are doing a wildcard lookup, populate the
             # cache fully so that we can do an efficient lookup next time.
 
-            if types and any(k is None for (t, k) in types):
+            if filtered_types or (types and any(k is None for (t, k) in types)):
                 types_to_fetch = None
             else:
                 types_to_fetch = types
 
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types_to_fetch, filtered_types
+                missing_groups, types_to_fetch
             )
 
             for group, group_state_dict in iteritems(group_to_state_dict):
@@ -659,7 +638,10 @@ class StateGroupWorkerStore(SQLBaseStore):
                 if types:
                     for k, v in iteritems(group_state_dict):
                         (typ, _) = k
-                        if k in types or (typ, None) in types:
+                        if (
+                            (k in types or (typ, None) in types) or
+                            (filtered_types and typ not in filtered_types)
+                        ):
                             state_dict[k] = v
                 else:
                     state_dict.update(group_state_dict)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8924ba9f7f..b2f314e9db 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -158,3 +158,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
             (e2.type, e2.state_key): e2,
             (e3.type, e3.state_key): e3,
         }, state)
+
+        state = yield self.store.get_state_for_event(
+            e5.event_id, [], filtered_types=[EventTypes.Member],
+        )
+
+        self.assertStateMapEqual({
+            (e1.type, e1.state_key): e1,
+            (e2.type, e2.state_key): e2,
+        }, state)