summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 18ae8aee29..ea5cbdac08 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -16,6 +16,8 @@ import collections.abc
 import logging
 from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
 
+import attr
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -43,6 +47,15 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventMetadata:
+    """Returned by `get_metadata_for_events`"""
+
+    room_id: str
+    event_type: str
+    state_key: Optional[str]
+
+
 def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
     v = KNOWN_ROOM_VERSIONS.get(room_version_id)
     if not v:
@@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return room_version
 
+    async def get_metadata_for_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventMetadata]:
+        """Get some metadata (room_id, type, state_key) for the given events.
+
+        This method is a faster alternative than fetching the full events from
+        the DB, and should be used when the full event is not needed.
+
+        Returns metadata for rejected and redacted events. Events that have not
+        been persisted are omitted from the returned dict.
+        """
+
+        def get_metadata_for_events_txn(
+            txn: LoggingTransaction,
+            batch_ids: Collection[str],
+        ) -> Dict[str, EventMetadata]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "e.event_id", batch_ids
+            )
+
+            sql = f"""
+                SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
+                LEFT JOIN state_events USING (event_id)
+                WHERE {clause}
+            """
+
+            txn.execute(sql, args)
+            return {
+                event_id: EventMetadata(
+                    room_id=room_id, event_type=event_type, state_key=state_key
+                )
+                for event_id, room_id, event_type, state_key in txn
+            }
+
+        result_map: Dict[str, EventMetadata] = {}
+        for batch_ids in batch_iter(event_ids, 1000):
+            result_map.update(
+                await self.db_pool.runInteraction(
+                    "get_metadata_for_events",
+                    get_metadata_for_events_txn,
+                    batch_ids=batch_ids,
+                )
+            )
+
+        return result_map
+
     async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
         """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.