summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/devices.py10
-rw-r--r--synapse/storage/databases/main/events.py27
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/presence.py61
-rw-r--r--synapse/storage/databases/main/purge_events.py13
-rw-r--r--synapse/storage/databases/main/registration.py21
-rw-r--r--synapse/storage/databases/main/relations.py26
-rw-r--r--synapse/storage/databases/main/room.py4
-rw-r--r--synapse/storage/databases/main/roommember.py62
-rw-r--r--synapse/storage/databases/main/search.py17
-rw-r--r--synapse/storage/databases/main/user_directory.py22
-rw-r--r--synapse/storage/databases/state/store.py203
-rw-r--r--synapse/storage/state.py8
13 files changed, 388 insertions, 90 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8d845fe951..3b3a089b76 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
+    def get_cached_device_list_changes(
+        self,
+        from_key: int,
+    ) -> Optional[Set[str]]:
+        """Get set of users whose devices have changed since `from_key`, or None
+        if that information is not in our cache.
+        """
+
+        return self._device_list_stream_cache.get_all_entities_changed(from_key)
+
     async def get_users_whose_devices_changed(
         self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5246fccad5..a1d7a9b413 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -975,6 +975,17 @@ class PersistEventsStore:
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
 
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = {
+                state_key
+                for ev_type, state_key in itertools.chain(to_delete, to_insert)
+                if ev_type == EventTypes.Member
+            }
+
             if delta_state.no_longer_in_room:
                 # Server is no longer in the room so we delete the room from
                 # current_state_events, being careful we've already updated the
@@ -993,6 +1004,11 @@ class PersistEventsStore:
                 """
                 txn.execute(sql, (stream_id, self._instance_name, room_id))
 
+                # We also want to invalidate the membership caches for users
+                # that were in the room.
+                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+                members_changed.update(users_in_room)
+
                 self.db_pool.simple_delete_txn(
                     txn,
                     table="current_state_events",
@@ -1102,17 +1118,6 @@ class PersistEventsStore:
 
             # Invalidate the various caches
 
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
-
             for member in members_changed:
                 txn.call_after(
                     self.store.get_rooms_for_user_with_stream_ordering.invalidate,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8d4287045a..2a255d1031 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
                 include the previous states content in the unsigned field.
 
             allow_rejected: If True, return rejected events. Otherwise,
-                omits rejeted events from the response.
+                omits rejected events from the response.
 
         Returns:
             A mapping from event_id to event.
@@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore):
             forward_edge_query = """
                 SELECT 1 FROM event_edges
                 /* Check to make sure the event referencing our event in question is not rejected */
-                LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+                LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
                 WHERE
                     event_edges.room_id = ?
                     AND event_edges.prev_event_id = ?
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a77..d3c4611686 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         # Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+        self._presence_id_gen: AbstractStreamIdGenerator
+
         self._can_persist_presence = (
-            hs.get_instance_name() in hs.config.worker.writers.presence
+            self._instance_name in hs.config.worker.writers.presence
         )
 
         if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+    def _update_presence_txn(
+        self, txn: LoggingTransaction, stream_orderings, presence_states
+    ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
                 self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_presence_updates_txn(txn):
+        def get_all_presence_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, user_id, state, last_active_ts,
                     last_federation_update_ts, last_user_sync_ts,
-                    status_msg,
-                currently_active
+                    status_msg, currently_active
                 FROM presence_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(row[0], row[1:]) for row in txn],
+            )
 
             upper_bound = current_id
             limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         )
 
     @cached()
-    def _get_presence_for_user(self, user_id):
+    def _get_presence_for_user(self, user_id: str) -> None:
         raise NotImplementedError()
 
     @cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         list_name="user_ids",
         num_args=1,
     )
-    async def get_presence_for_users(self, user_ids):
+    async def get_presence_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
         rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             True if the user should have full presence sent to them, False otherwise.
         """
 
-        def _should_user_receive_full_presence_with_token_txn(txn):
+        def _should_user_receive_full_presence_with_token_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT 1 FROM users_to_send_full_presence_to
                 WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             _should_user_receive_full_presence_with_token_txn,
         )
 
-    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
         """Adds to the list of users who should receive a full snapshot of presence
         upon their next sync.
 
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return users_to_state
 
-    def get_current_presence_token(self):
+    def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
-    def _get_active_presence(self, db_conn: Connection):
+    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
         """
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return [UserPresenceState(**row) for row in rows]
 
-    def take_presence_startup_info(self):
+    def take_presence_startup_info(self) -> List[UserPresenceState]:
         active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
+        self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index e87a8fb85d..2e3818e432 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Set, Tuple
+from typing import Any, List, Set, Tuple, cast
 
 from synapse.api.errors import SynapseError
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
 
     def _purge_history_txn(
-        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        token: RoomStreamToken,
+        delete_local_events: bool,
     ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         """,
             (room_id,),
         )
-        (min_depth,) = txn.fetchone()
+        (min_depth,) = cast(Tuple[int], txn.fetchone())
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "purge_room", self._purge_room_txn, room_id
         )
 
-    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
+    def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index aac94fa464..17110bb033 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ) -> None:
         """Record a mapping from an external user id to a mxid
 
+        See notes in _record_user_external_id_txn about what constitutes valid data.
+
         Args:
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
             user_id: complete mxid that it is mapped to
+
         Raises:
             ExternalIDReuseException if the new external_id could not be mapped.
         """
@@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         external_id: str,
         user_id: str,
     ) -> None:
+        """
+        Record a mapping from an external user id to a mxid.
+
+        Note that the auth provider IDs (and the external IDs) are not validated
+        against configured IdPs as Synapse does not know its relationship to
+        external systems. For example, it might be useful to pre-configure users
+        before enabling a new IdP or an IdP might be temporarily offline, but
+        still valid.
+
+        Args:
+            txn: The database transaction.
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
 
         self.db_pool.simple_insert_txn(
             txn,
@@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Replace mappings from external user ids to a mxid in a single transaction.
         All mappings are deleted and the new ones are created.
 
+        See notes in _record_user_external_id_txn about what constitutes valid data.
+
         Args:
             record_external_ids:
                 List with tuple of auth_provider and external_id to record
             user_id: complete mxid that it is mapped to
+
         Raises:
             ExternalIDReuseException if the new external_id could not be mapped.
         """
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e2c27e594b..36aa1092f6 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
+    # The latest event in the thread.
     latest_event: EventBase
+    # The latest edit to the latest event in the thread.
+    latest_edit: Optional[EventBase]
+    # The total number of events in the thread.
     count: int
+    # True if the current user has sent an event to the thread.
     current_user_participated: bool
 
 
@@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
     async def _get_thread_summaries(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
-        """Get the number of threaded replies and the latest reply (if any) for the given event.
+    ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
+        """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
 
         Args:
             event_ids: Summarize the thread related to this event ID.
@@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of the thread summary each event. A missing event implies there
             are no threaded replies.
 
-            Each summary includes the number of items in the thread and the most
-            recent response.
+            Each summary is a tuple of:
+                The number of events in the thread.
+                The most recent event in the thread.
+                The most recent edit to the most recent event in the thread, if applicable.
         """
 
         def _get_thread_summaries_txn(
@@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore):
             # TODO Should this only allow m.room.message events.
             if isinstance(self.database_engine, PostgresEngine):
                 # The `DISTINCT ON` clause will pick the *first* row it encounters,
-                # so ordering by topologica ordering + stream ordering desc will
+                # so ordering by topological ordering + stream ordering desc will
                 # ensure we get the latest event in the thread.
                 sql = """
                     SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
@@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
+        # Check to see if any of those events are edited.
+        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+
         # Map to the event IDs to the thread summary.
         #
         # There might not be a summary due to there not being a thread or
@@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
 
             summary = None
             if latest_event:
-                summary = (counts[parent_event_id], latest_event)
+                latest_edit = latest_edits.get(latest_event_id)
+                summary = (counts[parent_event_id], latest_event, latest_edit)
             summaries[parent_event_id] = summary
 
         return summaries
@@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
             )
             for event_id, summary in summaries.items():
                 if summary:
-                    thread_count, latest_thread_event = summary
+                    thread_count, latest_thread_event, edit = summary
                     results.setdefault(
                         event_id, BundledAggregations()
                     ).thread = _ThreadAggregation(
                         latest_event=latest_thread_event,
+                        latest_edit=edit,
                         count=thread_count,
                         # If there's a thread summary it must also exist in the
                         # participated dictionary.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 95167116c9..0416df64ce 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
 
     async def upsert_room_on_join(
-        self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+        self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
     ) -> None:
         """Ensure that the room is stored in the table
 
@@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         has_auth_chain_index = await self.has_auth_chain_index(room_id)
 
         create_event = None
-        for e in auth_events:
+        for e in state_events:
             if (e.type, e.state_key) == (EventTypes.Create, ""):
                 create_event = e
                 break
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4489732fda..e48ec5f495 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             for room_id, instance, stream_id in txn
         )
 
+    @cachedList(
+        cached_method_name="get_rooms_for_user_with_stream_ordering",
+        list_name="user_ids",
+    )
+    async def get_rooms_for_users_with_stream_ordering(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+        """A batched version of `get_rooms_for_user_with_stream_ordering`.
+
+        Returns:
+            Map from user_id to set of rooms that is currently in.
+        """
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_users_with_stream_ordering",
+            self._get_rooms_for_users_with_stream_ordering_txn,
+            user_ids,
+        )
+
+    def _get_rooms_for_users_with_stream_ordering_txn(
+        self, txn, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "c.state_key",
+            user_ids,
+        )
+
+        if self._current_state_events_membership_up_to_date:
+            sql = f"""
+                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND c.membership = ?
+                    AND {clause}
+            """
+        else:
+            sql = f"""
+                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS m USING (room_id, event_id)
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND m.membership = ?
+                    AND {clause}
+            """
+
+        txn.execute(sql, [Membership.JOIN] + args)
+
+        result = {user_id: set() for user_id in user_ids}
+        for user_id, room_id, instance, stream_id in txn:
+            result[user_id].add(
+                GetRoomsForUserWithStreamOrdering(
+                    room_id, PersistedEventPosition(instance, stream_id)
+                )
+            )
+
+        return {user_id: frozenset(v) for user_id, v in result.items()}
+
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a5764..acea300ed3 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
     ):
         super().__init__(database, db_conn, hs)
 
-    async def search_msgs(self, room_ids, search_term, keys):
+    async def search_msgs(
+        self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
-            room_ids (list): List of room ids to search in
-            search_term (str): Search term to search for
-            keys (list): List of keys to search in, currently supports
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            list of dicts
+            Dictionary of results
         """
         clauses = []
 
@@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         self,
         room_ids: Collection[str],
         search_term: str,
-        keys: List[str],
+        keys: Iterable[str],
         limit,
         pagination_token: Optional[str] = None,
-    ) -> List[dict]:
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f7c778bdf2..e7fddd2426 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         processed_event_count = 0
 
         for room_id, event_count in rooms_to_work_on:
-            is_in_room = await self.is_host_joined(room_id, self.server_name)
+            is_in_room = await self.is_host_joined(room_id, self.server_name)  # type: ignore[attr-defined]
 
             if is_in_room:
-                users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+                users_with_profile = await self.get_users_in_room_with_profiles(room_id)  # type: ignore[attr-defined]
                 # Throw away users excluded from the directory.
                 users_with_profile = {
                     user_id: profile
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         for user_id in users_to_work_on:
             if await self.should_include_local_user_in_dir(user_id):
-                profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+                profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined]
                 await self.update_profile_in_user_dir(
                     user_id, profile.display_name, profile.avatar_url
                 )
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # technically it could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice sender can be
         # contacted.
-        if self.get_app_service_by_user_id(user) is not None:
+        if self.get_app_service_by_user_id(user) is not None:  # type: ignore[attr-defined]
             return False
 
         # We're opting to exclude appservice users (anyone matching the user
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # they could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice users can be
         # contacted.
-        if self.get_if_app_services_interested_in_user(user):
+        if self.get_if_app_services_interested_in_user(user):  # type: ignore[attr-defined]
             # TODO we might want to make this configurable for each app service
             return False
 
         # Support users are for diagnostics and should not appear in the user directory.
-        if await self.is_support_user(user):
+        if await self.is_support_user(user):  # type: ignore[attr-defined]
             return False
 
         # Deactivated users aren't contactable, so should not appear in the user directory.
         try:
-            if await self.get_user_deactivated_status(user):
+            if await self.get_user_deactivated_status(user):  # type: ignore[attr-defined]
                 return False
         except StoreError:
             # No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             (EventTypes.RoomHistoryVisibility, ""),
         )
 
-        current_state_ids = await self.get_filtered_current_state_ids(
+        current_state_ids = await self.get_filtered_current_state_ids(  # type: ignore[attr-defined]
             room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
         if join_rules_id:
-            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
+            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)  # type: ignore[attr-defined]
             if join_rule_ev:
                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
                     return True
 
         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
         if hist_vis_id:
-            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
+            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)  # type: ignore[attr-defined]
             if hist_vis_ev:
                 if (
                     hist_vis_ev.content.get("history_visibility")
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76ac6..3af69a2076 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,11 +13,23 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import (
+    AbstractObservableDeferred,
+    ObservableDeferred,
+    yieldable_gather_results,
+)
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -37,7 +55,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
 MAX_STATE_DELTA_HOPS = 100
 
 
@@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        # Current ongoing get_state_for_groups in-flight requests
+        # {group ID -> {StateFilter -> ObservableDeferred}}
+        self._state_group_inflight_requests: Dict[
+            int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
+        ] = {}
+
         def get_max_state_group_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]  # type: ignore
@@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     async def _get_state_groups_from_groups(
-        self, groups: List[int], state_filter: StateFilter
+        self, groups: Sequence[int], state_filter: StateFilter
     ) -> Dict[int, StateMap[str]]:
         """Returns the state groups for a given set of groups from the
         database, filtering on types of state events.
@@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
+    def _get_state_for_group_gather_inflight_requests(
+        self, group: int, state_filter_left_over: StateFilter
+    ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
+        """
+        Attempts to gather in-flight requests and re-use them to retrieve state
+        for the given state group, filtered with the given state filter.
+
+        Used as part of _get_state_for_group_using_inflight_cache.
+
+        Returns:
+            Tuple of two values:
+                A sequence of ObservableDeferreds to observe
+                A StateFilter representing what else needs to be requested to fulfill the request
+        """
+
+        inflight_requests = self._state_group_inflight_requests.get(group)
+        if inflight_requests is None:
+            # no requests for this group, need to retrieve it all ourselves
+            return (), state_filter_left_over
+
+        # The list of ongoing requests which will help narrow the current request.
+        reusable_requests = []
+        for (request_state_filter, request_deferred) in inflight_requests.items():
+            new_state_filter_left_over = state_filter_left_over.approx_difference(
+                request_state_filter
+            )
+            if new_state_filter_left_over == state_filter_left_over:
+                # Reusing this request would not gain us anything, so don't bother.
+                continue
+
+            reusable_requests.append(request_deferred)
+            state_filter_left_over = new_state_filter_left_over
+            if state_filter_left_over == StateFilter.none():
+                # we have managed to collect enough of the in-flight requests
+                # to cover our StateFilter and give us the state we need.
+                break
+
+        return reusable_requests, state_filter_left_over
+
+    async def _get_state_for_group_fire_request(
+        self, group: int, state_filter: StateFilter
+    ) -> StateMap[str]:
+        """
+        Fires off a request to get the state at a state group,
+        potentially filtering by type and/or state key.
+
+        This request will be tracked in the in-flight request cache and automatically
+        removed when it is finished.
+
+        Used as part of _get_state_for_group_using_inflight_cache.
+
+        Args:
+            group: ID of the state group for which we want to get state
+            state_filter: the state filter used to fetch state from the database
+        """
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        async def _the_request() -> StateMap[str]:
+            group_to_state_dict = await self._get_state_groups_from_groups(
+                (group,), state_filter=db_state_filter
+            )
+
+            # Now let's update the caches
+            self._insert_into_cache(
+                group_to_state_dict,
+                db_state_filter,
+                cache_seq_num_members=cache_sequence_m,
+                cache_seq_num_non_members=cache_sequence_nm,
+            )
+
+            # Remove ourselves from the in-flight cache
+            group_request_dict = self._state_group_inflight_requests[group]
+            del group_request_dict[db_state_filter]
+            if not group_request_dict:
+                # If there are no more requests in-flight for this group,
+                # clean up the cache by removing the empty dictionary
+                del self._state_group_inflight_requests[group]
+
+            return group_to_state_dict[group]
+
+        # We don't immediately await the result, so must use run_in_background
+        # But we DO await the result before the current log context (request)
+        # finishes, so don't need to run it as a background process.
+        request_deferred = run_in_background(_the_request)
+        observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
+
+        # Insert the ObservableDeferred into the cache
+        group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
+        group_request_dict[db_state_filter] = observable_deferred
+
+        return await make_deferred_yieldable(observable_deferred.observe())
+
+    async def _get_state_for_group_using_inflight_cache(
+        self, group: int, state_filter: StateFilter
+    ) -> MutableStateMap[str]:
+        """
+        Gets the state at a state group, potentially filtering by type and/or
+        state key.
+
+        1. Calls _get_state_for_group_gather_inflight_requests to gather any
+           ongoing requests which might overlap with the current request.
+        2. Fires a new request, using _get_state_for_group_fire_request,
+           for any state which cannot be gathered from ongoing requests.
+
+        Args:
+            group: ID of the state group for which we want to get state
+            state_filter: the state filter used to fetch state from the database
+        Returns:
+            state map
+        """
+
+        # first, figure out whether we can re-use any in-flight requests
+        # (and if so, what would be left over)
+        (
+            reusable_requests,
+            state_filter_left_over,
+        ) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
+
+        if state_filter_left_over != StateFilter.none():
+            # Fetch remaining state
+            remaining = await self._get_state_for_group_fire_request(
+                group, state_filter_left_over
+            )
+            assembled_state: MutableStateMap[str] = dict(remaining)
+        else:
+            assembled_state = {}
+
+        gathered = await make_deferred_yieldable(
+            defer.gatherResults(
+                (r.observe() for r in reusable_requests), consumeErrors=True
+            )
+        ).addErrback(unwrapFirstError)
+
+        # assemble our result.
+        for result_piece in gathered:
+            assembled_state.update(result_piece)
+
+        # Filter out any state that may be more than what we asked for.
+        return state_filter.filter_state(assembled_state)
+
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
@@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         if not incomplete_groups:
             return state
 
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
-
-        # Help the cache hit ratio by expanding the filter a bit
-        db_state_filter = state_filter.return_expanded()
-
-        group_to_state_dict = await self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
+        async def get_from_cache(group: int, state_filter: StateFilter) -> None:
+            state[group] = await self._get_state_for_group_using_inflight_cache(
+                group, state_filter
+            )
 
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
+        await yieldable_gather_results(
+            get_from_cache,
+            incomplete_groups,
+            state_filter,
         )
 
-        # And finally update the result dict, by filtering out any extra
-        # stuff we pulled out of the database.
-        for group, group_state_dict in group_to_state_dict.items():
-            # We just replace any existing entries, as we will have loaded
-            # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
-
         return state
 
     def _get_state_for_groups_using_cache(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 913448f0f9..e79ecf64a0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -204,13 +204,16 @@ class StateFilter:
         if get_all_members:
             # We want to return everything.
             return StateFilter.all()
-        else:
+        elif EventTypes.Member in self.types:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
                 types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
+        else:
+            # We want to return all non-members
+            return _ALL_NON_MEMBER_STATE_FILTER
 
     def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
         """Converts the filter to an SQL clause.
@@ -528,6 +531,9 @@ class StateFilter:
 
 
 _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+    types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
 _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)