summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py4
-rw-r--r--synapse/storage/profile.py4
-rw-r--r--synapse/storage/room.py58
-rw-r--r--synapse/storage/state.py158
4 files changed, 175 insertions, 49 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 025a7fb6d9..f39c8c8461 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
         }
 
         events_map = {ev.event_id: ev for ev, _ in events_context}
+        room_version = yield self.get_room_version(room_id)
+
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, state_groups, events_map, get_events
+            room_id, room_version, state_groups, events_map, get_events
         )
 
         defer.returnValue((res.state, None))
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 60295da254..88b50f33b5 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_from_remote_profile_cache",
         )
 
-
-class ProfileStore(ProfileWorkerStore):
     def create_profile(self, user_localpart):
         return self._simple_insert(
             table="profiles",
@@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore):
             desc="set_profile_avatar_url",
         )
 
+
+class ProfileStore(ProfileWorkerStore):
     def add_remote_profile_cache(self, user_id, displayname, avatar_url):
         """Ensure we are caching the remote user's profiles.
 
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 3378fc77d1..61013b8919 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
             desc="is_room_blocked",
         )
 
+    @cachedInlineCallbacks(max_entries=10000)
+    def get_ratelimit_for_user(self, user_id):
+        """Check if there are any overrides for ratelimiting for the given
+        user
+
+        Args:
+            user_id (str)
+
+        Returns:
+            RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimitng has been
+            disabled for that user entirely.
+        """
+        row = yield self._simple_select_one(
+            table="ratelimit_override",
+            keyvalues={"user_id": user_id},
+            retcols=("messages_per_second", "burst_count"),
+            allow_none=True,
+            desc="get_ratelimit_for_user",
+        )
+
+        if row:
+            defer.returnValue(RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
+            ))
+        else:
+            defer.returnValue(None)
+
 
 class RoomStore(RoomWorkerStore, SearchStore):
 
@@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
-    @cachedInlineCallbacks(max_entries=10000)
-    def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
-
-        Args:
-            user_id (str)
-
-        Returns:
-            RatelimitOverride if there is an override, else None. If the contents
-            of RatelimitOverride are None or 0 then ratelimitng has been
-            disabled for that user entirely.
-        """
-        row = yield self._simple_select_one(
-            table="ratelimit_override",
-            keyvalues={"user_id": user_id},
-            retcols=("messages_per_second", "burst_count"),
-            allow_none=True,
-            desc="get_ratelimit_for_user",
-        )
-
-        if row:
-            defer.returnValue(RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            ))
-        else:
-            defer.returnValue(None)
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         yield self._simple_insert(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dd03c4168b..4b971efdba 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
+        # Originally the state store used a single DictionaryCache to cache the
+        # event IDs for the state types in a given state group to avoid hammering
+        # on the state_group* tables.
+        #
+        # The point of using a DictionaryCache is that it can cache a subset
+        # of the state events for a given state group (i.e. a subset of the keys for a
+        # given dict which is an entry in the cache for a given state group ID).
+        #
+        # However, this poses problems when performing complicated queries
+        # on the store - for instance: "give me all the state for this group, but
+        # limit members to this subset of users", as DictionaryCache's API isn't
+        # rich enough to say "please cache any of these fields, apart from this subset".
+        # This is problematic when lazy loading members, which requires this behaviour,
+        # as without it the cache has no choice but to speculatively load all
+        # state events for the group, which negates the efficiency being sought.
+        #
+        # Rather than overcomplicating DictionaryCache's API, we instead split the
+        # state_group_cache into two halves - one for tracking non-member events,
+        # and the other for tracking member_events.  This means that lazy loading
+        # queries can be made in a cache-friendly manner by querying both caches
+        # separately and then merging the result.  So for the example above, you
+        # would query the members cache for a specific subset of state keys
+        # (which DictionaryCache will handle efficiently and fine) and the non-members
+        # cache for all state (which DictionaryCache will similarly handle fine)
+        # and then just merge the results together.
+        #
+        # We size the non-members cache to be smaller than the members cache as the
+        # vast majority of state in Matrix (today) is member events.
+
         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
+            "*stateGroupCache*",
+            # TODO: this hasn't been tuned yet
+            50000 * get_cache_factor_for("stateGroupCache")
+        )
+        self._state_group_members_cache = DictionaryCache(
+            "*stateGroupMembersCache*",
+            500000 * get_cache_factor_for("stateGroupMembersCache")
         )
 
     @defer.inlineCallbacks
@@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types):
+    def _get_state_groups_from_groups(self, groups, types, members=None):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
@@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             types (Iterable[str, str|None]|None): list of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
                 state_keys for the `type`. If None, all types are returned.
+            members (bool|None): If not None, then, in addition to any filtering
+                implied by types, the results are also filtered to only include
+                member events (if True), or to exclude member events (if False)
 
         Returns:
             dictionary state_group -> (dict of (type, state_key) -> event id)
@@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types,
+                self._get_state_groups_from_groups_txn, chunk, types, members,
             )
             results.update(res)
 
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, types=None,
+        self, txn, groups, types=None, members=None,
     ):
         results = {group: {} for group in groups}
 
@@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 %s
             """)
 
+            if members is True:
+                sql += " AND type = '%s'" % (EventTypes.Member,)
+            elif members is False:
+                sql += " AND type <> '%s'" % (EventTypes.Member,)
+
             # Turns out that postgres doesn't like doing a list of OR's and
             # is about 1000x slower, so we just issue a query for each specific
             # type seperately.
@@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             else:
                 where_clause = ""
 
+            if members is True:
+                where_clause += " AND type = '%s'" % EventTypes.Member
+            elif members is False:
+                where_clause += " AND type <> '%s'" % EventTypes.Member
+
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
@@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
-    def _get_some_state_from_cache(self, group, types, filtered_types=None):
+    def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
+            cache(DictionaryCache): the state group cache to use
             group(int): The state group to lookup
             types(list[str, str|None]): List of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
@@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         requests state from the cache, if False we need to query the DB for the
         missing state.
         """
-        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
+        is_all, known_absent, state_dict_ids = cache.get(group)
 
         type_to_key = {}
 
-        # tracks whether any of ourrequested types are missing from the cache
+        # tracks whether any of our requested types are missing from the cache
         missing_types = False
 
         for typ, state_key in types:
@@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             if include(k[0], k[1])
         }, got_all
 
-    def _get_all_state_from_cache(self, group):
+    def _get_all_state_from_cache(self, cache, group):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         cache, if False we need to query the DB for the missing state.
 
         Args:
+            cache(DictionaryCache): the state group cache to use
             group: The state group to lookup
         """
-        is_all, _, state_dict_ids = self._state_group_cache.get(group)
+        is_all, _, state_dict_ids = cache.get(group)
 
         return state_dict_ids, is_all
 
@@ -685,6 +735,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Deferred[dict[int, dict[(type, state_key), EventBase]]]
                 a dictionary mapping from state group to state dictionary.
         """
+        if types is not None:
+            non_member_types = [t for t in types if t[0] != EventTypes.Member]
+
+            if filtered_types is not None and EventTypes.Member not in filtered_types:
+                # we want all of the membership events
+                member_types = None
+            else:
+                member_types = [t for t in types if t[0] == EventTypes.Member]
+
+        else:
+            non_member_types = None
+            member_types = None
+
+        non_member_state = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_cache, non_member_types, filtered_types,
+        )
+        # XXX: we could skip this entirely if member_types is []
+        member_state = yield self._get_state_for_groups_using_cache(
+            # we set filtered_types=None as member_state only ever contain members.
+            groups, self._state_group_members_cache, member_types, None,
+        )
+
+        state = non_member_state
+        for group in groups:
+            state[group].update(member_state[group])
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups_using_cache(
+        self, groups, cache, types=None, filtered_types=None
+    ):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key, querying from a specific cache.
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            cache (DictionaryCache): the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the specific
+                members state cache.
+            types (None|iterable[(str, None|str)]):
+                indicates the state type/keys required. If None, the whole
+                state is fetched and returned.
+
+                Otherwise, each entry should be a `(type, state_key)` tuple to
+                include in the response. A `state_key` of None is a wildcard
+                meaning that we require all state with that type.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+
+        Returns:
+            Deferred[dict[int, dict[(type, state_key), EventBase]]]
+                a dictionary mapping from state group to state dictionary.
+        """
         if types:
             types = frozenset(types)
         results = {}
@@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         if types is not None:
             for group in set(groups):
                 state_dict_ids, got_all = self._get_some_state_from_cache(
-                    group, types, filtered_types
+                    cache, group, types, filtered_types
                 )
                 results[group] = state_dict_ids
 
@@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         else:
             for group in set(groups):
                 state_dict_ids, got_all = self._get_all_state_from_cache(
-                    group
+                    cache, group
                 )
 
                 results[group] = state_dict_ids
@@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     missing_groups.append(group)
 
         if missing_groups:
-            # Okay, so we have some missing_types, lets fetch them.
-            cache_seq_num = self._state_group_cache.sequence
+            # Okay, so we have some missing_types, let's fetch them.
+            cache_seq_num = cache.sequence
 
             # the DictionaryCache knows if it has *all* the state, but
             # does not know if it has all of the keys of a particular type,
@@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 types_to_fetch = types
 
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types_to_fetch
+                missing_groups, types_to_fetch, cache == self._state_group_members_cache,
             )
 
             for group, group_state_dict in iteritems(group_to_state_dict):
@@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
                 # update the cache with all the things we fetched from the
                 # database.
-                self._state_group_cache.update(
+                cache.update(
                     cache_seq_num,
                     key=group,
                     value=group_state_dict,
@@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     ],
                 )
 
-            # Prefill the state group cache with this group.
+            # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
             # is immutable. (If the map wasn't immutable then this prefill could
             # race with another update)
+
+            current_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] == EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_members_cache.update,
+                self._state_group_members_cache.sequence,
+                key=state_group,
+                value=dict(current_member_state_ids),
+            )
+
+            current_non_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] != EventTypes.Member
+            }
             txn.call_after(
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_state_ids),
+                value=dict(current_non_member_state_ids),
             )
 
             return state_group