Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/cache.py            |   6
-rw-r--r--  synapse/storage/databases/main/event_federation.py |   5
-rw-r--r--  synapse/storage/databases/main/media_repository.py |  28
-rw-r--r--  synapse/storage/databases/main/roommember.py       |  83
-rw-r--r--  synapse/storage/databases/main/state.py            |  52
-rw-r--r--  synapse/storage/databases/main/stream.py           | 123
6 files changed, 279 insertions, 18 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d6b75e47e..26b8e1a172 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -331,6 +331,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 "get_invited_rooms_for_local_user", (state_key,)
             )
             self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
+            self._attempt_to_invalidate_cache(
+                "_get_rooms_for_local_user_where_membership_is_inner", (state_key,)
+            )
 
             self._attempt_to_invalidate_cache(
                 "did_forget",
@@ -393,6 +396,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
         self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
         self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+        self._attempt_to_invalidate_cache(
+            "_get_rooms_for_local_user_where_membership_is_inner", None
+        )
         self._attempt_to_invalidate_cache("did_forget", None)
         self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
         self._attempt_to_invalidate_cache("get_references_for_event", None)
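
Note: the invalidation calls above pair with the new `@cached(tree=True)`
method added in roommember.py below; every `@cached` function needs a matching
`_attempt_to_invalidate_cache` call wherever its inputs can change. A toy model
of the tree-cache behaviour being relied on (illustrative only, not Synapse's
real cache internals):

    # Tree caches key on tuples such as (user_id, membership_list); passing a
    # key prefix like (state_key,) evicts every entry for that user.
    from typing import Dict, Tuple

    class ToyTreeCache:
        def __init__(self) -> None:
            self._entries: Dict[Tuple, object] = {}

        def set(self, key: Tuple, value: object) -> None:
            self._entries[key] = value

        def invalidate_prefix(self, prefix: Tuple) -> None:
            # tree=True allows eviction by key prefix, which is why the calls
            # above only need to pass (state_key,).
            for key in [k for k in self._entries if k[: len(prefix)] == prefix]:
                del self._entries[key]

    cache = ToyTreeCache()
    cache.set(("@alice:x", frozenset({"join"})), ["!room1:x"])
    cache.set(("@alice:x", frozenset({"leave"})), [])
    cache.invalidate_prefix(("@alice:x",))
    assert not cache._entries
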
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24abab4a23..715846865b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
         last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]
+        if last_change is None:
+            # If the room isn't in the cache we know that the last change was
+            # somewhere before the earliest known position of the cache, so we
+            # can clamp to that.
+            last_change = self._events_stream_cache.get_earliest_known_position()  # type: ignore[attr-defined]
 
         # We don't always have a full stream_to_exterm_id table, e.g. after
         # the upgrade that introduced it, so we make sure we never ask for a
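
Note: the None-handling above relies on a contract of `StreamChangeCache`: if
an entity is absent from the cache, its last change can only have happened at
or before the cache's earliest known position. A simplified model of that
contract (illustrative; the real class lives in
synapse/util/caches/stream_change_cache.py):

    from typing import Dict, Optional

    class ToyStreamChangeCache:
        def __init__(self, earliest_known: int) -> None:
            self._earliest_known = earliest_known
            self._last_change: Dict[str, int] = {}

        def entity_has_changed(self, entity: str, pos: int) -> None:
            self._last_change[entity] = max(pos, self._last_change.get(entity, 0))

        def get_max_pos_of_last_change(self, entity: str) -> Optional[int]:
            # None means "no change seen in the cache window", not "never changed".
            return self._last_change.get(entity)

        def get_earliest_known_position(self) -> int:
            return self._earliest_known

    cache = ToyStreamChangeCache(earliest_known=100)
    last_change = cache.get_max_pos_of_last_change("!room:example.org")
    if last_change is None:
        # Safe clamp: any change must predate the cache window.
        last_change = cache.get_earliest_known_position()
    assert last_change == 100
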
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 6128332af8..7617fd3ad4 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -64,6 +64,7 @@ class LocalMedia:
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
     user_id: Optional[str]
+    authenticated: Optional[bool]
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -77,6 +78,7 @@ class RemoteMedia:
     created_ts: int
     last_access_ts: int
     quarantined_by: Optional[str]
+    authenticated: Optional[bool]
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -218,6 +220,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "last_access_ts",
                 "safe_from_quarantine",
                 "user_id",
+                "authenticated",
             ),
             allow_none=True,
             desc="get_local_media",
@@ -235,6 +238,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             last_access_ts=row[6],
             safe_from_quarantine=row[7],
             user_id=row[8],
+            authenticated=row[9],
         )
 
     async def get_local_media_by_user_paginate(
@@ -290,7 +294,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     last_access_ts,
                     quarantined_by,
                     safe_from_quarantine,
-                    user_id
+                    user_id,
+                    authenticated
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -314,6 +319,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     quarantined_by=row[7],
                     safe_from_quarantine=bool(row[8]),
                     user_id=row[9],
+                    authenticated=row[10],
                 )
                 for row in txn
             ]
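
Note: these queries unpack rows positionally, so a new column must be appended
to both the SELECT list and the matching `row[n]` index, as the diff does with
`authenticated` at index 10. A contrived illustration of that coupling (the row
values are made up; only the index arithmetic matters):

    row = ("m1", "image/png", 123, 0, "f.png", 456, 789, None, True, "@u:x", True)
    authenticated = row[10]  # appended last, so indices 0..9 stay valid
    assert authenticated is True
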
@@ -417,12 +423,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         time_now_ms: int,
         user_id: UserID,
     ) -> None:
+        if self.hs.config.media.enable_authenticated_media:
+            authenticated = True
+        else:
+            authenticated = False
+
         await self.db_pool.simple_insert(
             "local_media_repository",
             {
                 "media_id": media_id,
                 "created_ts": time_now_ms,
                 "user_id": user_id.to_string(),
+                "authenticated": authenticated,
             },
             desc="store_local_media_id",
         )
@@ -438,6 +450,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id: UserID,
         url_cache: Optional[str] = None,
     ) -> None:
+        if self.hs.config.media.enable_authenticated_media:
+            authenticated = True
+        else:
+            authenticated = False
+
         await self.db_pool.simple_insert(
             "local_media_repository",
             {
@@ -448,6 +465,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "media_length": media_length,
                 "user_id": user_id.to_string(),
                 "url_cache": url_cache,
+                "authenticated": authenticated,
             },
             desc="store_local_media",
         )
@@ -638,6 +656,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "filesystem_id",
                 "last_access_ts",
                 "quarantined_by",
+                "authenticated",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
@@ -654,6 +673,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             filesystem_id=row[4],
             last_access_ts=row[5],
             quarantined_by=row[6],
+            authenticated=row[7],
         )
 
     async def store_cached_remote_media(
@@ -666,6 +686,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         upload_name: Optional[str],
         filesystem_id: str,
     ) -> None:
+        if self.hs.config.media.enable_authenticated_media:
+            authenticated = True
+        else:
+            authenticated = False
+
         await self.db_pool.simple_insert(
             "remote_media_cache",
             {
@@ -677,6 +702,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name": upload_name,
                 "filesystem_id": filesystem_id,
                 "last_access_ts": time_now_ms,
+                "authenticated": authenticated,
             },
             desc="store_cached_remote_media",
         )
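
Note: the four-line if/else computing `authenticated` appears three times in
this file; a direct boolean assignment would be equivalent (a possible
simplification, not what the commit does). `media_config` below stands in for
`self.hs.config.media`:

    from types import SimpleNamespace

    media_config = SimpleNamespace(enable_authenticated_media=True)  # stand-in
    authenticated = bool(media_config.enable_authenticated_media)
    assert authenticated is True
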
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 5d2fd08495..640ab123f0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -279,8 +279,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 
     @cached(max_entries=100000)  # type: ignore[synapse-@cached-mutable]
     async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
-        """Get the details of a room roughly suitable for use by the room
+        """
+        Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
+
+        Returns the total count of members in the room by membership type, and a
+        truncated list of members (the heroes). This will be the first 6 members of the
+        room:
+        - We want 5 heroes plus 1, in case one of them is the calling
+          user.
+        - Joined or invited members come first, ordered by `stream_ordering`.
+          When no joined or invited members are available, this also includes
+          banned and left users.
+
         Args:
             room_id: The room ID to query
         Returns:
@@ -308,23 +319,36 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             for count, membership in txn:
                 res.setdefault(membership, MemberSummary([], count))
 
-            # we order by membership and then fairly arbitrarily by event_id so
-            # heroes are consistent
-            # Note, rejected events will have a null membership field, so
-            # we we manually filter them out.
+            # Order by membership (joins -> invites -> leave (former insiders) ->
+            # everything else (outsiders like bans/knocks)), then by
+            # `stream_ordering` so the earliest members in the room show up
+            # first and the sort is stable (consistent heroes).
+            #
+            # Note: rejected events will have a null membership field, so we
+            # manually filter them out.
             sql = """
                 SELECT state_key, membership, event_id
                 FROM current_state_events
                 WHERE type = 'm.room.member' AND room_id = ?
                     AND membership IS NOT NULL
                 ORDER BY
-                    CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                    event_id ASC
+                    CASE membership WHEN ? THEN 1 WHEN ? THEN 2 WHEN ? THEN 3 ELSE 4 END ASC,
+                    event_stream_ordering ASC
                 LIMIT ?
             """
 
-            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
-            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+            txn.execute(
+                sql,
+                (
+                    room_id,
+                    # Sort order
+                    Membership.JOIN,
+                    Membership.INVITE,
+                    Membership.LEAVE,
+                    # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+                    6,
+                ),
+            )
             for user_id, membership, event_id in txn:
                 summary = res[membership]
                 # we will always have a summary for this membership type at this
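
Note: a Python model of the new ORDER BY above: the CASE expression ranks
joins before invites before leaves before everything else, and
`event_stream_ordering` breaks ties so the earliest members surface first
(data below is illustrative):

    MEMBERSHIP_RANK = {"join": 1, "invite": 2, "leave": 3}

    members = [  # (user_id, membership, stream_ordering)
        ("@ban:x", "ban", 5),
        ("@old:x", "join", 1),
        ("@new:x", "join", 9),
        ("@inv:x", "invite", 3),
    ]
    members.sort(key=lambda m: (MEMBERSHIP_RANK.get(m[1], 4), m[2]))
    heroes = [user_id for user_id, _, _ in members[:6]]
    assert heroes == ["@old:x", "@new:x", "@inv:x", "@ban:x"]
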
@@ -421,9 +445,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         if not membership_list:
             return []
 
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_local_user_where_membership_is",
-            self._get_rooms_for_local_user_where_membership_is_txn,
+        # Convert membership list to frozen set as a) it needs to be hashable,
+        # and b) we don't care about the order.
+        membership_list = frozenset(membership_list)
+
+        rooms = await self._get_rooms_for_local_user_where_membership_is_inner(
             user_id,
             membership_list,
         )
@@ -442,6 +468,24 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 
         return [room for room in rooms if room.room_id not in rooms_to_exclude]
 
+    @cached(max_entries=1000, tree=True)
+    async def _get_rooms_for_local_user_where_membership_is_inner(
+        self,
+        user_id: str,
+        membership_list: Collection[str],
+    ) -> Sequence[RoomsForUser]:
+        if not membership_list:
+            return []
+
+        rooms = await self.db_pool.runInteraction(
+            "get_rooms_for_local_user_where_membership_is",
+            self._get_rooms_for_local_user_where_membership_is_txn,
+            user_id,
+            membership_list,
+        )
+
+        return rooms
+
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
         txn: LoggingTransaction,
@@ -1509,10 +1553,19 @@ def extract_heroes_from_room_summary(
 ) -> List[str]:
     """Determine the users that represent a room, from the perspective of the `me` user.
 
+    This function expects `MemberSummary.members` to already be sorted by
+    `stream_ordering` like the results from `get_room_summary(...)`.
+
     The rules which say which users we select are specified in the "Room Summary"
     section of
     https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync
 
+
+    Args:
+        details: Mapping from membership type to member summary. We expect
+            `MemberSummary.members` to already be sorted by `stream_ordering`.
+        me: The user for whom we are determining the heroes.
+
     Returns a list (possibly empty) of heroes' mxids.
     """
     empty_ms = MemberSummary([], 0)
@@ -1527,11 +1580,11 @@ def extract_heroes_from_room_summary(
         r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
     ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
 
-    # FIXME: order by stream ordering rather than as returned by SQL
+    # We expect `MemberSummary.members` to already be sorted by `stream_ordering`
     if joined_user_ids or invited_user_ids:
-        return sorted(joined_user_ids + invited_user_ids)[0:5]
+        return (joined_user_ids + invited_user_ids)[0:5]
     else:
-        return sorted(gone_user_ids)[0:5]
+        return gone_user_ids[0:5]
 
 
 @attr.s(slots=True, auto_attribs=True)
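
Note: two details in this file are easy to miss: `@cached` keys must be
hashable (hence the frozenset conversion), and `extract_heroes_from_room_summary`
can now slice the first five entries directly because the SQL already ordered
them by `stream_ordering`. A quick sketch of the hashability point:

    # A frozenset is hashable and canonicalises the order of memberships, so
    # ("join", "invite") and ("invite", "join") share one cache entry.
    key_a = frozenset(["join", "invite"])
    key_b = frozenset(["invite", "join"])
    assert key_a == key_b and hash(key_a) == hash(key_b)
    # A plain list would raise: TypeError: unhashable type: 'list'
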
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b2a67aff89..5188b2f7a4 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -41,7 +41,7 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
@@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         create_event = await self.get_event(create_id)
         return create_event
 
+    @cached(max_entries=10000)
+    async def get_room_type(self, room_id: str) -> Optional[str]:
+        """Get the room type for a given room. The server must be joined to the
+        given room.
+        """
+
+        row = await self.db_pool.simple_select_one(
+            table="room_stats_state",
+            keyvalues={"room_id": room_id},
+            retcols=("room_type",),
+            allow_none=True,
+            desc="get_room_type",
+        )
+
+        if row is not None:
+            return row[0]
+
+        # If we haven't updated `room_stats_state` with the room yet, query the
+        # create event directly.
+        create_event = await self.get_create_event_for_room(room_id)
+        room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+        return room_type
+
+    @cachedList(cached_method_name="get_room_type", list_name="room_ids")
+    async def bulk_get_room_type(
+        self, room_ids: Set[str]
+    ) -> Mapping[str, Optional[str]]:
+        """Bulk fetch room types for the given rooms, the server must be in all
+        the rooms given.
+        """
+
+        rows = await self.db_pool.simple_select_many_batch(
+            table="room_stats_state",
+            column="room_id",
+            iterable=room_ids,
+            retcols=("room_id", "room_type"),
+            desc="bulk_get_room_type",
+        )
+
+        # If we haven't updated `room_stats_state` with the room yet, query the
+        # create events directly. This should happen only rarely so we don't
+        # mind if we do this in a loop.
+        results = dict(rows)
+        for room_id in room_ids - results.keys():
+            create_event = await self.get_create_event_for_room(room_id)
+            room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+            results[room_id] = room_type
+
+        return results
+
     @cached(max_entries=100000, iterable=True)
     async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b034361aec 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             get_last_event_pos_in_room_before_stream_ordering_txn,
         )
 
+    async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+        self,
+        room_ids: StrCollection,
+        end_token: RoomStreamToken,
+    ) -> Dict[str, int]:
+        """Bulk fetch the stream position of the latest events in the given
+        rooms
+        """
+
+        min_token = end_token.stream
+        max_token = end_token.get_max_stream_pos()
+        results: Dict[str, int] = {}
+
+        # First, we check the stream change cache to see if it already has a
+        # usable latest position for each room.
+        missing_room_ids: Set[str] = set()
+        for room_id in room_ids:
+            stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+            if stream_pos and stream_pos <= min_token:
+                results[room_id] = stream_pos
+            else:
+                missing_room_ids.add(room_id)
+
+        # Next, we query the stream positions from the DB. First we fetch all
+        # positions less than the *max* stream pos in the token, then filter
+        # them down. We do this as a) it is a cheaper query, and b) the vast
+        # majority of rooms will have a latest position from before the min
+        # stream pos.
+
+        def bulk_get_last_event_pos_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            # This query fetches the latest stream position in the rooms before
+            # the given max position.
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, (
+                    SELECT stream_ordering FROM events AS e
+                    LEFT JOIN rejections USING (event_id)
+                    WHERE e.room_id = r.room_id
+                        AND stream_ordering <= ?
+                        AND NOT outlier
+                        AND rejection_reason IS NULL
+                    ORDER BY stream_ordering DESC
+                    LIMIT 1
+                )
+                FROM rooms AS r
+                WHERE {clause}
+            """
+            txn.execute(sql, [max_token] + args)
+            return {row[0]: row[1] for row in txn}
+
+        recheck_rooms: Set[str] = set()
+        for batched in batch_iter(missing_room_ids, 1000):
+            result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering",
+                bulk_get_last_event_pos_txn,
+                batched,
+            )
+
+            # Check that the stream positions for the rooms are before the
+            # minimum position of the token. If not, we need to fetch more
+            # rows.
+            for room_id, stream in result.items():
+                if stream <= min_token:
+                    results[room_id] = stream
+                else:
+                    recheck_rooms.add(room_id)
+
+        if not recheck_rooms:
+            return results
+
+        # For the remaining rooms we need to fetch all rows between the min and
+        # max stream positions in the end token, and filter out the rows that
+        # are after the end token.
+        #
+        # This query should be fast as the range between the min and max should
+        # be small.
+
+        def bulk_get_last_event_pos_recheck_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, instance_name, stream_ordering
+                FROM events
+                WHERE ? < stream_ordering AND stream_ordering <= ?
+                    AND NOT outlier
+                    AND rejection_reason IS NULL
+                    AND {clause}
+                ORDER BY stream_ordering ASC
+            """
+            txn.execute(sql, [min_token, max_token] + args)
+
+            # We take the max stream ordering that is at or before the token.
+            # Since we ordered by stream ordering ascending, we just iterate
+            # through and keep the last matching stream ordering.
+            txn_results: Dict[str, int] = {}
+            for row in txn:
+                room_id = row[0]
+                event_pos = PersistedEventPosition(row[1], row[2])
+                if not event_pos.persisted_after(end_token):
+                    txn_results[room_id] = event_pos.stream
+
+            return txn_results
+
+        for batched in batch_iter(recheck_rooms, 1000):
+            recheck_result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+                bulk_get_last_event_pos_recheck_txn,
+                batched,
+            )
+            results.update(recheck_result)
+
+        return results
+
     async def get_current_room_stream_token_for_room_id(
         self, room_id: str
     ) -> RoomStreamToken:
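
Note: taken together, `bulk_get_last_event_pos_in_room_before_stream_ordering`
is a three-phase lookup: the stream change cache first, then one cheap bulk
query clamped at the token's max position, then a narrow recheck of the
(min, max] range for the stragglers. A toy simulation of phases two and three
(the data and helper are made up; the real recheck compares
`PersistedEventPosition` against the token):

    from typing import Dict, List, Set, Tuple

    events: List[Tuple[str, int]] = [  # (room_id, stream_ordering)
        ("!a:x", 3), ("!a:x", 7), ("!b:x", 5), ("!b:x", 12),
    ]
    min_token, max_token = 10, 15

    def latest_at_most(room: str, bound: int) -> int:
        return max(s for r, s in events if r == room and s <= bound)

    results: Dict[str, int] = {}
    recheck: Set[str] = set()
    for room in {"!a:x", "!b:x"}:
        pos = latest_at_most(room, max_token)  # phase 2: cheap clamped query
        if pos <= min_token:
            results[room] = pos  # usually already before min_token
        else:
            recheck.add(room)
    for room in recheck:
        results[room] = latest_at_most(room, min_token)  # phase 3: narrow recheck
    assert results == {"!a:x": 7, "!b:x": 5}
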