diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d6b75e47e..26b8e1a172 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -331,6 +331,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_invited_rooms_for_local_user", (state_key,)
)
self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
+ self._attempt_to_invalidate_cache(
+ "_get_rooms_for_local_user_where_membership_is_inner", (state_key,)
+ )
self._attempt_to_invalidate_cache(
"did_forget",
@@ -393,6 +396,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+ self._attempt_to_invalidate_cache(
+ "_get_rooms_for_local_user_where_membership_is_inner", None
+ )
self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("get_references_for_event", None)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24abab4a23..715846865b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
+ if last_change is None:
+ # If the room isn't in the cache we know that the last change was
+ # somewhere before the earliest known position of the cache, so we
+ # can clamp to that.
+ last_change = self._events_stream_cache.get_earliest_known_position() # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
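The fallback added here relies on a property of the stream change cache: if a room has no entry, every change to it happened at or before the cache's earliest known position, so clamping to that horizon is safe, if conservative. A toy model of the two methods used above (illustrative only, not Synapse's `StreamChangeCache`):

```python
# Toy stream change cache: tracks the last change per entity, but only
# knows about changes after its "earliest known position" horizon.
from typing import Dict, Optional

class ToyStreamChangeCache:
    def __init__(self, earliest_known_position: int) -> None:
        self._earliest = earliest_known_position
        self._last_change: Dict[str, int] = {}

    def entity_has_changed(self, entity: str, pos: int) -> None:
        self._last_change[entity] = max(pos, self._last_change.get(entity, 0))

    def get_max_pos_of_last_change(self, entity: str) -> Optional[int]:
        return self._last_change.get(entity)

    def get_earliest_known_position(self) -> int:
        return self._earliest

cache = ToyStreamChangeCache(earliest_known_position=100)
cache.entity_has_changed("!room:test", 150)

last_change = cache.get_max_pos_of_last_change("!other:test")
if last_change is None:
    # Unknown room: its last change was at or before the cache horizon,
    # so clamp to the horizon, as the hunk above does.
    last_change = cache.get_earliest_known_position()
assert last_change == 100
```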
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 6128332af8..7617fd3ad4 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -64,6 +64,7 @@ class LocalMedia:
quarantined_by: Optional[str]
safe_from_quarantine: bool
user_id: Optional[str]
+ authenticated: Optional[bool]
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -77,6 +78,7 @@ class RemoteMedia:
created_ts: int
last_access_ts: int
quarantined_by: Optional[str]
+ authenticated: Optional[bool]
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -218,6 +220,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"last_access_ts",
"safe_from_quarantine",
"user_id",
+ "authenticated",
),
allow_none=True,
desc="get_local_media",
@@ -235,6 +238,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
last_access_ts=row[6],
safe_from_quarantine=row[7],
user_id=row[8],
+ authenticated=row[9],
)
async def get_local_media_by_user_paginate(
@@ -290,7 +294,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
last_access_ts,
quarantined_by,
safe_from_quarantine,
- user_id
+ user_id,
+ authenticated
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@@ -314,6 +319,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
quarantined_by=row[7],
safe_from_quarantine=bool(row[8]),
user_id=row[9],
+ authenticated=row[10],
)
for row in txn
]
@@ -417,12 +423,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
time_now_ms: int,
user_id: UserID,
) -> None:
+ if self.hs.config.media.enable_authenticated_media:
+ authenticated = True
+ else:
+ authenticated = False
+
await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
"created_ts": time_now_ms,
"user_id": user_id.to_string(),
+ "authenticated": authenticated,
},
desc="store_local_media_id",
)
@@ -438,6 +450,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id: UserID,
url_cache: Optional[str] = None,
) -> None:
+ if self.hs.config.media.enable_authenticated_media:
+ authenticated = True
+ else:
+ authenticated = False
+
await self.db_pool.simple_insert(
"local_media_repository",
{
@@ -448,6 +465,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"media_length": media_length,
"user_id": user_id.to_string(),
"url_cache": url_cache,
+ "authenticated": authenticated,
},
desc="store_local_media",
)
@@ -638,6 +656,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"filesystem_id",
"last_access_ts",
"quarantined_by",
+ "authenticated",
),
allow_none=True,
desc="get_cached_remote_media",
@@ -654,6 +673,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
+ authenticated=row[7],
)
async def store_cached_remote_media(
@@ -666,6 +686,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name: Optional[str],
filesystem_id: str,
) -> None:
+ if self.hs.config.media.enable_authenticated_media:
+ authenticated = True
+ else:
+ authenticated = False
+
await self.db_pool.simple_insert(
"remote_media_cache",
{
@@ -677,6 +702,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name": upload_name,
"filesystem_id": filesystem_id,
"last_access_ts": time_now_ms,
+ "authenticated": authenticated,
},
desc="store_cached_remote_media",
)
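All three insert paths above derive the new `authenticated` column the same way from `enable_authenticated_media`. A condensed sketch of that shared logic (the `MediaConfig` dataclass and `local_media_row` helper are hypothetical stand-ins for the real config and storage objects):

```python
# Condensed sketch of the shared pattern in the three inserts above.
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class MediaConfig:
    enable_authenticated_media: bool

def local_media_row(
    config: MediaConfig, media_id: str, user_id: str, now_ms: int
) -> Dict[str, Any]:
    # Equivalent to the if/else in the hunks: media stored while the
    # option is enabled is marked as requiring authentication to fetch.
    authenticated = config.enable_authenticated_media
    return {
        "media_id": media_id,
        "created_ts": now_ms,
        "user_id": user_id,
        "authenticated": authenticated,
    }

row = local_media_row(
    MediaConfig(enable_authenticated_media=True),
    "abc123", "@alice:test", 1_700_000_000_000,
)
assert row["authenticated"] is True
```

The if/else in the hunks is behaviourally equivalent to the direct boolean assignment in the sketch.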
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 5d2fd08495..640ab123f0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -279,8 +279,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
@cached(max_entries=100000) # type: ignore[synapse-@cached-mutable]
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
- """Get the details of a room roughly suitable for use by the room
+ """
+ Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
+
+ Returns the total count of members in the room by membership type, and a
+ truncated list of members (the heroes). This will be the first 6 members of the
+ room:
+        - We want 5 heroes plus 1, in case one of them is the
+          calling user.
+        - They are ordered by `stream_ordering`, preferring joined or
+          invited members. When no joined or invited members are available,
+          this also includes banned and left users.
+
Args:
room_id: The room ID to query
Returns:
@@ -308,23 +319,36 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for count, membership in txn:
res.setdefault(membership, MemberSummary([], count))
- # we order by membership and then fairly arbitrarily by event_id so
- # heroes are consistent
- # Note, rejected events will have a null membership field, so
- # we we manually filter them out.
+        # Order by membership (joins -> invites -> leave (former insiders) ->
+        # everything else (outsiders like bans/knocks)), then by `stream_ordering`
+        # so the earliest members in the room show up first and the sort is
+        # stable (consistent heroes).
+ #
+        # Note: rejected events will have a null membership field, so we manually
+ # filter them out.
sql = """
SELECT state_key, membership, event_id
FROM current_state_events
WHERE type = 'm.room.member' AND room_id = ?
AND membership IS NOT NULL
ORDER BY
- CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- event_id ASC
+ CASE membership WHEN ? THEN 1 WHEN ? THEN 2 WHEN ? THEN 3 ELSE 4 END ASC,
+ event_stream_ordering ASC
LIMIT ?
"""
- # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
- txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ txn.execute(
+ sql,
+ (
+ room_id,
+ # Sort order
+ Membership.JOIN,
+ Membership.INVITE,
+ Membership.LEAVE,
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ 6,
+ ),
+ )
for user_id, membership, event_id in txn:
summary = res[membership]
# we will always have a summary for this membership type at this
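The reworked query ranks memberships join -> invite -> leave -> everything else, then breaks ties by `stream_ordering`. A Python equivalent of that ordering (illustrative only; the sample rows are hypothetical, the membership strings match the Matrix constants):

```python
# Rank memberships join -> invite -> leave -> everything else, break ties
# by stream_ordering, and keep the first 6 rows, mirroring the SQL above.
from typing import List, Tuple

MEMBERSHIP_RANK = {"join": 1, "invite": 2, "leave": 3}

def pick_summary_rows(
    members: List[Tuple[str, str, int]],  # (user_id, membership, stream_ordering)
) -> List[Tuple[str, str, int]]:
    ordered = sorted(
        members,
        key=lambda m: (MEMBERSHIP_RANK.get(m[1], 4), m[2]),
    )
    # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
    return ordered[:6]

rows = [
    ("@zara:test", "ban", 5),
    ("@bob:test", "join", 2),
    ("@eve:test", "invite", 1),
    ("@alice:test", "join", 1),
]
assert [r[0] for r in pick_summary_rows(rows)][:3] == [
    "@alice:test", "@bob:test", "@eve:test",
]
```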
@@ -421,9 +445,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
if not membership_list:
return []
- rooms = await self.db_pool.runInteraction(
- "get_rooms_for_local_user_where_membership_is",
- self._get_rooms_for_local_user_where_membership_is_txn,
+ # Convert membership list to frozen set as a) it needs to be hashable,
+ # and b) we don't care about the order.
+ membership_list = frozenset(membership_list)
+
+ rooms = await self._get_rooms_for_local_user_where_membership_is_inner(
user_id,
membership_list,
)
@@ -442,6 +468,24 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
return [room for room in rooms if room.room_id not in rooms_to_exclude]
+ @cached(max_entries=1000, tree=True)
+ async def _get_rooms_for_local_user_where_membership_is_inner(
+ self,
+ user_id: str,
+ membership_list: Collection[str],
+ ) -> Sequence[RoomsForUser]:
+ if not membership_list:
+ return []
+
+ rooms = await self.db_pool.runInteraction(
+ "get_rooms_for_local_user_where_membership_is",
+ self._get_rooms_for_local_user_where_membership_is_txn,
+ user_id,
+ membership_list,
+ )
+
+ return rooms
+
def _get_rooms_for_local_user_where_membership_is_txn(
self,
txn: LoggingTransaction,
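The frozenset conversion above matters because `@cached` keys results by the method's arguments, which must be hashable, and because two orderings of the same memberships should hit the same entry. A toy demonstration using `functools.lru_cache` as a stand-in for Synapse's `@cached`:

```python
# A list argument would be unhashable (and order-sensitive); a frozenset
# is hashable and compares equal regardless of element order.
import functools

@functools.lru_cache(maxsize=1000)
def rooms_for_user_inner(user_id: str, membership_list: frozenset) -> tuple:
    # Stand-in for the actual DB transaction.
    return (user_id, tuple(sorted(membership_list)))

# Both call orders hit the same cache entry.
a = rooms_for_user_inner("@alice:test", frozenset(["join", "invite"]))
b = rooms_for_user_inner("@alice:test", frozenset(["invite", "join"]))
assert a is b
assert rooms_for_user_inner.cache_info().hits == 1
```

The `tree=True` flag on the new cache is what lets the invalidation in cache.py drop every entry for a user by its `(state_key,)` key prefix, whatever the `membership_list` argument was.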
@@ -1509,10 +1553,19 @@ def extract_heroes_from_room_summary(
) -> List[str]:
"""Determine the users that represent a room, from the perspective of the `me` user.
+    This function expects `MemberSummary.members` to already be sorted by
+    `stream_ordering`, as returned by `get_room_summary(...)`.
+
The rules which say which users we select are specified in the "Room Summary"
section of
https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync
+
+ Args:
+ details: Mapping from membership type to member summary. We expect
+ `MemberSummary.members` to already be sorted by `stream_ordering`.
+        me: The user for whom we are determining the heroes.
+
Returns a list (possibly empty) of heroes' mxids.
"""
empty_ms = MemberSummary([], 0)
@@ -1527,11 +1580,11 @@ def extract_heroes_from_room_summary(
r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
- # FIXME: order by stream ordering rather than as returned by SQL
+ # We expect `MemberSummary.members` to already be sorted by `stream_ordering`
if joined_user_ids or invited_user_ids:
- return sorted(joined_user_ids + invited_user_ids)[0:5]
+ return (joined_user_ids + invited_user_ids)[0:5]
else:
- return sorted(gone_user_ids)[0:5]
+ return gone_user_ids[0:5]
@attr.s(slots=True, auto_attribs=True)
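The behavioural change in this last hunk: heroes were previously re-sorted alphabetically by mxid, whereas now the `stream_ordering` from the summary query is preserved. With hypothetical data:

```python
# Heroes before and after this change, for already stream-ordered inputs.
joined = ["@bob:test", "@alice:test"]  # stream-ordered: bob joined first
invited = ["@zed:test"]

old_heroes = sorted(joined + invited)[0:5]  # what the removed lines did
new_heroes = (joined + invited)[0:5]        # what the added lines do

assert old_heroes == ["@alice:test", "@bob:test", "@zed:test"]
assert new_heroes == ["@bob:test", "@alice:test", "@zed:test"]
```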
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b2a67aff89..5188b2f7a4 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -41,7 +41,7 @@ from typing import (
import attr
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
@@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_event(create_id)
return create_event
+ @cached(max_entries=10000)
+ async def get_room_type(self, room_id: str) -> Optional[str]:
+ """Get the room type for a given room. The server must be joined to the
+ given room.
+ """
+
+ row = await self.db_pool.simple_select_one(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ retcols=("room_type",),
+ allow_none=True,
+ desc="get_room_type",
+ )
+
+ if row is not None:
+ return row[0]
+
+ # If we haven't updated `room_stats_state` with the room yet, query the
+ # create event directly.
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ return room_type
+
+ @cachedList(cached_method_name="get_room_type", list_name="room_ids")
+ async def bulk_get_room_type(
+ self, room_ids: Set[str]
+ ) -> Mapping[str, Optional[str]]:
+        """Bulk fetch room types for the given rooms; the server must be in all
+        the rooms given.
+ """
+
+ rows = await self.db_pool.simple_select_many_batch(
+ table="room_stats_state",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id", "room_type"),
+ desc="bulk_get_room_type",
+ )
+
+ # If we haven't updated `room_stats_state` with the room yet, query the
+ # create events directly. This should happen only rarely so we don't
+ # mind if we do this in a loop.
+ results = dict(rows)
+ for room_id in room_ids - results.keys():
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ results[room_id] = room_type
+
+ return results
+
@cached(max_entries=100000, iterable=True)
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b034361aec 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
get_last_event_pos_in_room_before_stream_ordering_txn,
)
+ async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+ self,
+ room_ids: StrCollection,
+ end_token: RoomStreamToken,
+ ) -> Dict[str, int]:
+ """Bulk fetch the stream position of the latest events in the given
+        rooms before the given end token.
+ """
+
+ min_token = end_token.stream
+ max_token = end_token.get_max_stream_pos()
+ results: Dict[str, int] = {}
+
+ # First, we check for the rooms in the stream change cache to see if we
+ # can just use the latest position from it.
+ missing_room_ids: Set[str] = set()
+ for room_id in room_ids:
+ stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+ if stream_pos and stream_pos <= min_token:
+ results[room_id] = stream_pos
+ else:
+ missing_room_ids.add(room_id)
+
+ # Next, we query the stream position from the DB. At first we fetch all
+ # positions less than the *max* stream pos in the token, then filter
+ # them down. We do this as a) this is a cheaper query, and b) the vast
+ # majority of rooms will have a latest token from before the min stream
+ # pos.
+
+ def bulk_get_last_event_pos_txn(
+ txn: LoggingTransaction, batch_room_ids: StrCollection
+ ) -> Dict[str, int]:
+ # This query fetches the latest stream position in the rooms before
+ # the given max position.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+ sql = f"""
+ SELECT room_id, (
+ SELECT stream_ordering FROM events AS e
+ LEFT JOIN rejections USING (event_id)
+ WHERE e.room_id = r.room_id
+ AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejection_reason IS NULL
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ )
+ FROM rooms AS r
+ WHERE {clause}
+ """
+ txn.execute(sql, [max_token] + args)
+ return {row[0]: row[1] for row in txn}
+
+ recheck_rooms: Set[str] = set()
+ for batched in batch_iter(missing_room_ids, 1000):
+ result = await self.db_pool.runInteraction(
+ "bulk_get_last_event_pos_in_room_before_stream_ordering",
+ bulk_get_last_event_pos_txn,
+ batched,
+ )
+
+            # Check that the stream positions for the rooms are from before the
+ # minimum position of the token. If not then we need to fetch more
+ # rows.
+ for room_id, stream in result.items():
+ if stream <= min_token:
+ results[room_id] = stream
+ else:
+ recheck_rooms.add(room_id)
+
+ if not recheck_rooms:
+ return results
+
+ # For the remaining rooms we need to fetch all rows between the min and
+ # max stream positions in the end token, and filter out the rows that
+ # are after the end token.
+ #
+ # This query should be fast as the range between the min and max should
+ # be small.
+
+ def bulk_get_last_event_pos_recheck_txn(
+ txn: LoggingTransaction, batch_room_ids: StrCollection
+ ) -> Dict[str, int]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+ sql = f"""
+ SELECT room_id, instance_name, stream_ordering
+ FROM events
+ WHERE ? < stream_ordering AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejection_reason IS NULL
+ AND {clause}
+ ORDER BY stream_ordering ASC
+ """
+ txn.execute(sql, [min_token, max_token] + args)
+
+ # We take the max stream ordering that is less than the token. Since
+ # we ordered by stream ordering we just need to iterate through and
+ # take the last matching stream ordering.
+ txn_results: Dict[str, int] = {}
+ for row in txn:
+ room_id = row[0]
+ event_pos = PersistedEventPosition(row[1], row[2])
+ if not event_pos.persisted_after(end_token):
+ txn_results[room_id] = event_pos.stream
+
+ return txn_results
+
+ for batched in batch_iter(recheck_rooms, 1000):
+ recheck_result = await self.db_pool.runInteraction(
+ "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+ bulk_get_last_event_pos_recheck_txn,
+ batched,
+ )
+ results.update(recheck_result)
+
+ return results
+
async def get_current_room_stream_token_for_room_id(
self, room_id: str
) -> RoomStreamToken:
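The new bulk method is a two-phase strategy: a cheap query for the latest position at or below the token's maximum bound, then a recheck over the small window between the token's minimum and maximum bounds for rooms whose result might still be after the token for their writer. A pure-Python sketch of that strategy (simplified token model; the real code works in SQL batches with `RoomStreamToken` and `PersistedEventPosition`):

```python
# The token is modelled as a minimum stream position plus per-writer
# maxima; an event is "after" the token if it exceeds its writer's bound.
from typing import Dict, List, Set, Tuple

Event = Tuple[str, str, int]  # (room_id, instance_name, stream_ordering)

def persisted_after(
    instance: str, pos: int, stream: int, instance_map: Dict[str, int]
) -> bool:
    return pos > instance_map.get(instance, stream)

def bulk_last_pos(
    events: List[Event], stream: int, instance_map: Dict[str, int]
) -> Dict[str, int]:
    max_token = max(instance_map.values(), default=stream)
    results: Dict[str, int] = {}
    recheck: Set[str] = set()

    # Phase 1 (the cheap query): latest position per room at or below the
    # max bound; keep it only when it is unambiguously before the token.
    latest: Dict[str, int] = {}
    for room_id, _, pos in events:
        if pos <= max_token:
            latest[room_id] = max(pos, latest.get(room_id, 0))
    for room_id, pos in latest.items():
        if pos <= stream:
            results[room_id] = pos
        else:
            recheck.add(room_id)

    # Phase 2 (the recheck): walk the small (stream, max_token] window in
    # order, keeping the last position per room not after the token.
    for room_id, instance, pos in sorted(events, key=lambda e: e[2]):
        if room_id in recheck and stream < pos <= max_token:
            if not persisted_after(instance, pos, stream, instance_map):
                results[room_id] = pos
    return results

events = [("!a:test", "writer1", 5), ("!a:test", "writer1", 12), ("!b:test", "writer2", 11)]
# Token: minimum stream 10; writer1 advanced to 12, writer2 to 10.
out = bulk_last_pos(events, stream=10, instance_map={"writer1": 12, "writer2": 10})
# !a's event at 12 is within writer1's bound; !b's only event (11) is past
# writer2's bound, so !b is absent from the result.
assert out == {"!a:test": 12}
```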