summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-04-17 19:44:40 +0100
committerErik Johnston <erik@matrix.org>2019-04-17 19:44:40 +0100
commitca90336a6935b36b5761244005b0f68b496d5d79 (patch)
tree6bbce5eafc0db3b24ccc3b59b051da850382ae09 /synapse/storage/state.py
parentAdd management endpoints for account validity (diff)
parentMerge pull request #5047 from matrix-org/babolivier/account_expiration (diff)
downloadsynapse-ca90336a6935b36b5761244005b0f68b496d5d79.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/account_expiration
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py237
1 files changed, 102 insertions, 135 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6ddc4055d2..0bfe1b4550 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
-class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+class _GetStateGroupDelta(
+    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
     """Return type of get_state_group_delta that implements __len__, which lets
     us use the itrable flag when caching
     """
+
     __slots__ = []
 
     def __len__(self):
@@ -70,10 +73,7 @@ class StateFilter(object):
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
-            self.types = {
-                k: v for k, v in iteritems(self.types)
-                if v is not None
-            }
+            self.types = {k: v for k, v in iteritems(self.types) if v is not None}
 
     @staticmethod
     def all():
@@ -130,10 +130,7 @@ class StateFilter(object):
         Returns:
             StateFilter
         """
-        return StateFilter(
-            types={EventTypes.Member: set(members)},
-            include_others=True,
-        )
+        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
 
     def return_expanded(self):
         """Creates a new StateFilter where type wild cards have been removed
@@ -243,9 +240,7 @@ class StateFilter(object):
             if where_clause:
                 where_clause += " OR "
 
-            where_clause += "type NOT IN (%s)" % (
-                ",".join(["?"] * len(self.types)),
-            )
+            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
             where_args.extend(self.types)
 
         return where_clause, where_args
@@ -305,12 +300,8 @@ class StateFilter(object):
             bool
         """
 
-        return (
-            self.include_others
-            or any(
-                state_keys is None
-                for state_keys in itervalues(self.types)
-            )
+        return self.include_others or any(
+            state_keys is None for state_keys in itervalues(self.types)
         )
 
     def concrete_types(self):
@@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         self._state_group_cache = DictionaryCache(
             "*stateGroupCache*",
             # TODO: this hasn't been tuned yet
-            50000 * get_cache_factor_for("stateGroupCache")
+            50000 * get_cache_factor_for("stateGroupCache"),
         )
         self._state_group_members_cache = DictionaryCache(
             "*stateGroupMembersCache*",
-            500000 * get_cache_factor_for("stateGroupMembersCache")
+            500000 * get_cache_factor_for("stateGroupMembersCache"),
         )
 
     @defer.inlineCallbacks
@@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             deferred: dict of (type, state_key) -> event_id
         """
+
         def _get_current_state_ids_txn(txn):
             txn.execute(
                 """SELECT type, state_key, event_id FROM current_state_events
                 WHERE room_id = ?
                 """,
-                (room_id,)
+                (room_id,),
             )
 
             return {
                 (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
             }
 
-        return self.runInteraction(
-            "get_current_state_ids",
-            _get_current_state_ids_txn,
-        )
+        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
 
     # FIXME: how should this be cached?
     def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             return results
 
         return self.runInteraction(
-            "get_filtered_current_state_ids",
-            _get_filtered_current_state_ids_txn,
+            "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
     @defer.inlineCallbacks
@@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Deferred[str|None]: The canonical alias, if any
         """
 
-        state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types(
-            [(EventTypes.CanonicalAlias, "")]
-        ))
+        state = yield self.get_filtered_current_state_ids(
+            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+        )
 
         event_id = state.get((EventTypes.CanonicalAlias, ""))
         if not event_id:
@@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             (prev_group, delta_ids), where both may be None.
         """
+
         def _get_state_group_delta_txn(txn):
             prev_group = self._simple_select_one_onecol_txn(
                 txn,
                 table="state_group_edges",
-                keyvalues={
-                    "state_group": state_group,
-                },
+                keyvalues={"state_group": state_group},
                 retcol="prev_state_group",
                 allow_none=True,
             )
@@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             delta_ids = self._simple_select_list_txn(
                 txn,
                 table="state_groups_state",
-                keyvalues={
-                    "state_group": state_group,
-                },
-                retcols=("type", "state_key", "event_id",)
+                keyvalues={"state_group": state_group},
+                retcols=("type", "state_key", "event_id"),
             )
 
-            return _GetStateGroupDelta(prev_group, {
-                (row["type"], row["state_key"]): row["event_id"]
-                for row in delta_ids
-            })
-        return self.runInteraction(
-            "get_state_group_delta",
-            _get_state_group_delta_txn,
-        )
+            return _GetStateGroupDelta(
+                prev_group,
+                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+            )
+
+        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, _room_id, event_ids):
@@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         if not event_ids:
             defer.returnValue({})
 
-        event_to_groups = yield self._get_state_group_for_events(
-            event_ids,
-        )
+        event_to_groups = yield self._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups)
@@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in itervalues(group_to_ids)
+                ev_id
+                for group_ids in itervalues(group_to_ids)
                 for ev_id in itervalues(group_ids)
             ],
-            get_prev_content=False
+            get_prev_content=False,
         )
 
-        defer.returnValue({
-            group: [
-                state_event_map[v] for v in itervalues(event_id_map)
-                if v in state_event_map
-            ]
-            for group, event_id_map in iteritems(group_to_ids)
-        })
+        defer.returnValue(
+            {
+                group: [
+                    state_event_map[v]
+                    for v in itervalues(event_id_map)
+                    if v in state_event_map
+                ]
+                for group, event_id_map in iteritems(group_to_ids)
+            }
+        )
 
     @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, state_filter):
@@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         results = {}
 
-        chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
+        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, state_filter,
+                self._get_state_groups_from_groups_txn,
+                chunk,
+                state_filter,
             )
             results.update(res)
 
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all(),
+        self, txn, groups, state_filter=StateFilter.all()
     ):
         results = {group: {} for group in groups}
 
@@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
                         " WHERE state_group = ? " + where_clause,
-                        args
+                        args,
                     )
                     results[group].update(
                         ((typ, state_key), event_id)
@@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     # wildcards (i.e. Nones) in which case we have to do an exhaustive
                     # search
                     if (
-                        max_entries_returned is not None and
-                        len(results[group]) == max_entries_returned
+                        max_entries_returned is not None
+                        and len(results[group]) == max_entries_returned
                     ):
                         break
 
@@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
         """
-        event_to_groups = yield self._get_state_group_for_events(
-            event_ids,
-        )
+        event_to_groups = yield self._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups, state_filter)
 
         state_event_map = yield self.get_events(
             [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
-            get_prev_content=False
+            get_prev_content=False,
         )
 
         event_to_state = {
@@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred dict from event_id -> (type, state_key) -> event_id
         """
-        event_to_groups = yield self._get_state_group_for_events(
-            event_ids,
-        )
+        event_to_groups = yield self._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups, state_filter)
@@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     def _get_state_group_for_event(self, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
-            keyvalues={
-                "event_id": event_id,
-            },
+            keyvalues={"event_id": event_id},
             retcol="state_group",
             allow_none=True,
             desc="_get_state_group_for_event",
         )
 
-    @cachedList(cached_method_name="_get_state_group_for_event",
-                list_name="event_ids", num_args=1, inlineCallbacks=True)
+    @cachedList(
+        cached_method_name="_get_state_group_for_event",
+        list_name="event_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
     def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
@@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             column="event_id",
             iterable=event_ids,
             keyvalues={},
-            retcols=("event_id", "state_group",),
+            retcols=("event_id", "state_group"),
             desc="_get_state_group_for_events",
         )
 
@@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         # Now we look them up in the member and non-member caches
         non_member_state, incomplete_groups_nm, = (
             yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_cache,
-                state_filter=non_member_filter,
+                groups, self._state_group_cache, state_filter=non_member_filter
             )
         )
 
         member_state, incomplete_groups_m, = (
             yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_members_cache,
-                state_filter=member_filter,
+                groups, self._state_group_members_cache, state_filter=member_filter
             )
         )
 
@@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         db_state_filter = state_filter.return_expanded()
 
         group_to_state_dict = yield self._get_state_groups_from_groups(
-            list(incomplete_groups),
-            state_filter=db_state_filter,
+            list(incomplete_groups), state_filter=db_state_filter
         )
 
         # Now lets update the caches
@@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         defer.returnValue(state)
 
-    def _get_state_for_groups_using_cache(
-        self, groups, cache, state_filter,
-    ):
+    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
@@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results, incomplete_groups
 
-    def _insert_into_cache(self, group_to_state_dict, state_filter,
-                           cache_seq_num_members, cache_seq_num_non_members):
+    def _insert_into_cache(
+        self,
+        group_to_state_dict,
+        state_filter,
+        cache_seq_num_members,
+        cache_seq_num_non_members,
+    ):
         """Inserts results from querying the database into the relevant cache.
 
         Args:
@@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
-    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
-                          current_state_ids):
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
@@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             Deferred[int]: The state group ID
         """
+
         def _store_state_group_txn(txn):
             if current_state_ids is None:
                 # AFAIK, this can never happen
@@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
-                values={
-                    "id": state_group,
-                    "room_id": room_id,
-                    "event_id": event_id,
-                },
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
             )
 
             # We persist as a delta if we can, while also ensuring the chain
@@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                         % (prev_group,)
                     )
 
-                potential_hops = self._count_state_group_hops_txn(
-                    txn, prev_group
-                )
+                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
             if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                 self._simple_insert_txn(
                     txn,
                     table="state_group_edges",
-                    values={
-                        "state_group": state_group,
-                        "prev_state_group": prev_group,
-                    },
+                    values={"state_group": state_group, "prev_state_group": prev_group},
                 )
 
                 self._simple_insert_many_txn(
@@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         This is used to ensure the delta chains don't get too long.
         """
         if isinstance(self.database_engine, PostgresEngine):
-            sql = ("""
+            sql = """
                 WITH RECURSIVE state(state_group) AS (
                     VALUES(?::bigint)
                     UNION ALL
@@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     WHERE s.state_group = e.state_group
                 )
                 SELECT count(*) FROM state;
-            """)
+            """
 
             txn.execute(sql, (state_group,))
             row = txn.fetchone()
@@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
             self._background_deduplicate_state,
         )
         self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME,
-            self._background_index_state,
+            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
         )
         self.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
             txn,
             table="event_to_state_groups",
             values=[
-                {
-                    "state_group": state_group_id,
-                    "event_id": event_id,
-                }
+                {"state_group": state_group_id, "event_id": event_id}
                 for event_id, state_group_id in iteritems(state_groups)
             ],
         )
 
         for event_id, state_group_id in iteritems(state_groups):
             txn.call_after(
-                self._get_state_group_for_event.prefill,
-                (event_id,), state_group_id
+                self._get_state_group_for_event.prefill, (event_id,), state_group_id
             )
 
     @defer.inlineCallbacks
@@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
 
         if max_group is None:
             rows = yield self._execute(
-                "_background_deduplicate_state", None,
+                "_background_deduplicate_state",
+                None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
             )
             max_group = rows[0][0]
@@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                     " WHERE ? < id AND id <= ?"
                     " ORDER BY id ASC"
                     " LIMIT 1",
-                    (new_last_state_group, max_group,)
+                    (new_last_state_group, max_group),
                 )
                 row = txn.fetchone()
                 if row:
@@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                 txn.execute(
                     "SELECT state_group FROM state_group_edges"
                     " WHERE state_group = ?",
-                    (state_group,)
+                    (state_group,),
                 )
 
                 # If we reach a point where we've already started inserting
@@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                 txn.execute(
                     "SELECT coalesce(max(id), 0) FROM state_groups"
                     " WHERE id < ? AND room_id = ?",
-                    (state_group, room_id,)
+                    (state_group, room_id),
                 )
                 prev_group, = txn.fetchone()
                 new_last_state_group = state_group
 
                 if prev_group:
-                    potential_hops = self._count_state_group_hops_txn(
-                        txn, prev_group
-                    )
+                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
                     if potential_hops >= MAX_STATE_DELTA_HOPS:
                         # We want to ensure chains are at most this long,#
                         # otherwise read performance degrades.
                         continue
 
                     prev_state = self._get_state_groups_from_groups_txn(
-                        txn, [prev_group],
+                        txn, [prev_group]
                     )
                     prev_state = prev_state[prev_group]
 
                     curr_state = self._get_state_groups_from_groups_txn(
-                        txn, [state_group],
+                        txn, [state_group]
                     )
                     curr_state = curr_state[state_group]
 
@@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                         # of keys
 
                         delta_state = {
-                            key: value for key, value in iteritems(curr_state)
+                            key: value
+                            for key, value in iteritems(curr_state)
                             if prev_state.get(key, None) != value
                         }
 
                         self._simple_delete_txn(
                             txn,
                             table="state_group_edges",
-                            keyvalues={
-                                "state_group": state_group,
-                            }
+                            keyvalues={"state_group": state_group},
                         )
 
                         self._simple_insert_txn(
@@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                             values={
                                 "state_group": state_group,
                                 "prev_state_group": prev_group,
-                            }
+                            },
                         )
 
                         self._simple_delete_txn(
                             txn,
                             table="state_groups_state",
-                            keyvalues={
-                                "state_group": state_group,
-                            }
+                            keyvalues={"state_group": state_group},
                         )
 
                         self._simple_insert_many_txn(
@@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
         )
 
         if finished:
-            yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
+            yield self._end_background_update(
+                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+            )
 
         defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
 
@@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                         "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
                         " ON state_groups_state(state_group, type, state_key)"
                     )
-                    txn.execute(
-                        "DROP INDEX IF EXISTS state_groups_state_id"
-                    )
+                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
                 finally:
                     conn.set_session(autocommit=False)
             else:
@@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                     "CREATE INDEX state_groups_state_type_idx"
                     " ON state_groups_state(state_group, type, state_key)"
                 )
-                txn.execute(
-                    "DROP INDEX IF EXISTS state_groups_state_id"
-                )
+                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
 
         yield self.runWithConnection(reindex_txn)