summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- synapse/storage/databases/main/room.py 300
1 files changed, 218 insertions, 82 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index e41c99027a..78906a5e1d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1,5 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019, 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -50,8 +50,14 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import IdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    IdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -97,6 +103,12 @@ class RoomSortOrder(Enum):
     STATE_EVENTS = "state_events"
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PartialStateResyncInfo:
+    joined_via: Optional[str]
+    servers_in_room: List[str] = attr.ib(factory=list)
+
+
 class RoomWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -108,6 +120,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         self.config: HomeServerConfig = hs.config
 
+        self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+
+        if isinstance(database.engine, PostgresEngine):
+            self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="un_partial_stated_room_stream",
+                instance_name=self._instance_name,
+                tables=[
+                    ("un_partial_stated_room_stream", "instance_name", "stream_id")
+                ],
+                sequence_name="un_partial_stated_room_stream_sequence",
+                # TODO(faster_joins, multiple writers) Support multiple writers.
+                writers=["master"],
+            )
+        else:
+            self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
+                db_conn, "un_partial_stated_room_stream", "stream_id"
+            )
+
     async def store_room(
         self,
         room_id: str,
@@ -906,7 +938,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 event_json = db_to_json(content_json)
                 content = event_json["content"]
                 content_url = content.get("url")
-                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+                info = content.get("info")
+                if isinstance(info, dict):
+                    thumbnail_url = info.get("thumbnail_url")
+                else:
+                    thumbnail_url = None
 
                 for url in (content_url, thumbnail_url):
                     if not url:
@@ -1160,17 +1196,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             desc="get_partial_state_servers_at_join",
         )
 
-    async def get_partial_state_rooms_and_servers(
+    async def get_partial_state_room_resync_info(
         self,
-    ) -> Mapping[str, Collection[str]]:
-        """Get all rooms containing events with partial state, and the servers known
-        to be in the room.
+    ) -> Mapping[str, PartialStateResyncInfo]:
+        """Get all rooms containing events with partial state, and the information
+        needed to restart a "resync" of those rooms.
 
         Returns:
             A dictionary of rooms with partial state, with room IDs as keys and
             lists of servers in rooms as values.
         """
-        room_servers: Dict[str, List[str]] = {}
+        room_servers: Dict[str, PartialStateResyncInfo] = {}
+
+        rows = await self.db_pool.simple_select_list(
+            table="partial_state_rooms",
+            keyvalues={},
+            retcols=("room_id", "joined_via"),
+            desc="get_server_which_served_partial_join",
+        )
+
+        for row in rows:
+            room_id = row["room_id"]
+            joined_via = row["joined_via"]
+            room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
 
         rows = await self.db_pool.simple_select_list(
             "partial_state_rooms_servers",
@@ -1182,74 +1230,18 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         for row in rows:
             room_id = row["room_id"]
             server_name = row["server_name"]
-            room_servers.setdefault(room_id, []).append(server_name)
+            entry = room_servers.get(room_id)
+            if entry is None:
+                # There is a foreign key constraint which enforces that every room_id in
+                # partial_state_rooms_servers appears in partial_state_rooms. So we
+                # expect `entry` to be non-null. (This reasoning fails if we've
+                # partial-joined between the two SELECTs, but this is unlikely to happen
+                # in practice.)
+                continue
+            entry.servers_in_room.append(server_name)
 
         return room_servers
 
-    async def clear_partial_state_room(self, room_id: str) -> bool:
-        """Clears the partial state flag for a room.
-
-        Args:
-            room_id: The room whose partial state flag is to be cleared.
-
-        Returns:
-            `True` if the partial state flag has been cleared successfully.
-
-            `False` if the partial state flag could not be cleared because the room
-            still contains events with partial state.
-        """
-        try:
-            await self.db_pool.runInteraction(
-                "clear_partial_state_room", self._clear_partial_state_room_txn, room_id
-            )
-            return True
-        except self.db_pool.engine.module.IntegrityError as e:
-            # Assume that any `IntegrityError`s are due to partial state events.
-            logger.info(
-                "Exception while clearing lazy partial-state-room %s, retrying: %s",
-                room_id,
-                e,
-            )
-            return False
-
-    def _clear_partial_state_room_txn(
-        self, txn: LoggingTransaction, room_id: str
-    ) -> None:
-        DatabasePool.simple_delete_txn(
-            txn,
-            table="partial_state_rooms_servers",
-            keyvalues={"room_id": room_id},
-        )
-        DatabasePool.simple_delete_one_txn(
-            txn,
-            table="partial_state_rooms",
-            keyvalues={"room_id": room_id},
-        )
-        self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
-        self._invalidate_cache_and_stream(
-            txn, self.get_partial_state_servers_at_join, (room_id,)
-        )
-
-        # We now delete anything from `device_lists_remote_pending` with a
-        # stream ID less than the minimum
-        # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
-        device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
-            txn,
-            table="partial_state_rooms",
-            keyvalues={},
-            retcol="MIN(device_lists_stream_id)",
-            allow_none=True,
-        )
-        if device_lists_stream_id is None:
-            # There are no rooms being currently partially joined, so we delete everything.
-            txn.execute("DELETE FROM device_lists_remote_pending")
-        else:
-            sql = """
-                DELETE FROM device_lists_remote_pending
-                WHERE stream_id <= ?
-            """
-            txn.execute(sql, (device_lists_stream_id,))
-
     @cached()
     async def is_partial_state_room(self, room_id: str) -> bool:
         """Checks if this room has partial state.
@@ -1285,6 +1277,66 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         )
         return result["join_event_id"], result["device_lists_stream_id"]
 
+    def get_un_partial_stated_rooms_token(self) -> int:
+        # TODO(faster_joins, multiple writers): This is inappropriate if there
+        #     are multiple writers because workers that don't write often will
+        #     hold all readers up.
+        #     (See `MultiWriterIdGenerator.get_persisted_upto_position` for an
+        #      explanation.)
+        return self._un_partial_stated_rooms_stream_id_gen.get_current_token()
+
+    async def get_un_partial_stated_rooms_from_stream(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
+        """Get updates for the un-partial-stated rooms replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Only rows
+                written by this instance are returned.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
+
+        if last_id == current_id:
+            return [], current_id, False
+
+        def get_un_partial_stated_rooms_from_stream_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
+            sql = """
+                SELECT stream_id, room_id
+                FROM un_partial_stated_room_stream
+                WHERE ? < stream_id AND stream_id <= ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
+            updates = [(row[0], (row[1],)) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db_pool.runInteraction(
+            "get_un_partial_stated_rooms_from_stream",
+            get_un_partial_stated_rooms_from_stream_txn,
+        )
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1776,6 +1828,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
 
+        self._instance_name = hs.get_instance_name()
+
     async def upsert_room_on_join(
         self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
     ) -> None:
@@ -1817,9 +1871,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
                 "creator": room_creator,
                 "has_auth_chain_index": has_auth_chain_index,
             },
-            # rooms has a unique constraint on room_id, so no need to lock when doing an
-            # emulated upsert.
-            lock=False,
         )
 
     async def store_partial_state_room(
@@ -1827,6 +1878,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         room_id: str,
         servers: Collection[str],
         device_lists_stream_id: int,
+        joined_via: str,
     ) -> None:
         """Mark the given room as containing events with partial state.
 
@@ -1842,6 +1894,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             servers: other servers known to be in the room
             device_lists_stream_id: the device_lists stream ID at the time when we first
                 joined the room.
+            joined_via: the server name we requested a partial join from.
         """
         await self.db_pool.runInteraction(
             "store_partial_state_room",
@@ -1849,6 +1902,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             room_id,
             servers,
             device_lists_stream_id,
+            joined_via,
         )
 
     def _store_partial_state_room_txn(
@@ -1857,6 +1911,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         room_id: str,
         servers: Collection[str],
         device_lists_stream_id: int,
+        joined_via: str,
     ) -> None:
         DatabasePool.simple_insert_txn(
             txn,
@@ -1866,6 +1921,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
                 "device_lists_stream_id": device_lists_stream_id,
                 # To be updated later once the join event is persisted.
                 "join_event_id": None,
+                "joined_via": joined_via,
             },
         )
         DatabasePool.simple_insert_many_txn(
@@ -1935,9 +1991,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
                 "creator": "",
                 "has_auth_chain_index": has_auth_chain_index,
             },
-            # rooms has a unique constraint on room_id, so no need to lock when doing an
-            # emulated upsert.
-            lock=False,
         )
 
     async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
@@ -2026,7 +2079,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         Args:
             report_id: ID of reported event in database
         Returns:
-            event_report: json list of information from event report
+            JSON dict of information from an event report or None if the
+            report does not exist.
         """
 
         def _get_event_report_txn(
@@ -2099,8 +2153,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             user_id: search for user_id. Ignored if user_id is None
             room_id: search for room_id. Ignored if room_id is None
         Returns:
-            event_reports: json list of event reports
-            count: total number of event reports matching the filter criteria
+            Tuple of:
+                json list of event reports
+                total number of event reports matching the filter criteria
         """
 
         def _get_event_reports_paginate_txn(
@@ -2239,3 +2294,84 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             self.is_room_blocked,
             (room_id,),
         )
+
+    async def clear_partial_state_room(self, room_id: str) -> bool:
+        """Clears the partial state flag for a room.
+
+        Args:
+            room_id: The room whose partial state flag is to be cleared.
+
+        Returns:
+            `True` if the partial state flag has been cleared successfully.
+
+            `False` if the partial state flag could not be cleared because the room
+            still contains events with partial state.
+        """
+        try:
+            async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id:
+                await self.db_pool.runInteraction(
+                    "clear_partial_state_room",
+                    self._clear_partial_state_room_txn,
+                    room_id,
+                    un_partial_state_room_stream_id,
+                )
+                return True
+        except self.db_pool.engine.module.IntegrityError as e:
+            # Assume that any `IntegrityError`s are due to partial state events.
+            logger.info(
+                "Exception while clearing lazy partial-state-room %s, retrying: %s",
+                room_id,
+                e,
+            )
+            return False
+
+    def _clear_partial_state_room_txn(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        un_partial_state_room_stream_id: int,
+    ) -> None:
+        DatabasePool.simple_delete_txn(
+            txn,
+            table="partial_state_rooms_servers",
+            keyvalues={"room_id": room_id},
+        )
+        DatabasePool.simple_delete_one_txn(
+            txn,
+            table="partial_state_rooms",
+            keyvalues={"room_id": room_id},
+        )
+        self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+        self._invalidate_cache_and_stream(
+            txn, self.get_partial_state_servers_at_join, (room_id,)
+        )
+
+        DatabasePool.simple_insert_txn(
+            txn,
+            "un_partial_stated_room_stream",
+            {
+                "stream_id": un_partial_state_room_stream_id,
+                "instance_name": self._instance_name,
+                "room_id": room_id,
+            },
+        )
+
+        # We now delete anything from `device_lists_remote_pending` with a
+        # stream ID less than or equal to the minimum
+        # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
+        device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
+            txn,
+            table="partial_state_rooms",
+            keyvalues={},
+            retcol="MIN(device_lists_stream_id)",
+            allow_none=True,
+        )
+        if device_lists_stream_id is None:
+            # There are no rooms being currently partially joined, so we delete everything.
+            txn.execute("DELETE FROM device_lists_remote_pending")
+        else:
+            sql = """
+                DELETE FROM device_lists_remote_pending
+                WHERE stream_id <= ?
+            """
+            txn.execute(sql, (device_lists_stream_id,))