summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/devices.py65
-rw-r--r--synapse/storage/databases/main/events.py20
-rw-r--r--synapse/storage/databases/main/events_worker.py24
-rw-r--r--synapse/storage/databases/main/relations.py78
-rw-r--r--synapse/storage/databases/main/room.py31
-rw-r--r--synapse/storage/databases/main/state.py50
-rw-r--r--synapse/storage/databases/main/stream.py26
7 files changed, 174 insertions, 120 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dc8009b23d..318e4df376 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1582,7 +1582,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self,
         user_id: str,
         device_ids: Collection[str],
-        hosts: Optional[Collection[str]],
         room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
@@ -1592,9 +1591,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function will
                 return None.
-            hosts: The remote destinations that should be notified of the change. If
-                None then the set of hosts have *not* been calculated, and will be
-                calculated later by a background task.
             room_ids: The rooms that the user is in
 
         Returns:
@@ -1606,14 +1602,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         context = get_active_span_text_map()
 
-        def add_device_changes_txn(
-            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
-        ):
+        def add_device_changes_txn(txn, stream_ids):
             self._add_device_change_to_stream_txn(
                 txn,
                 user_id,
                 device_ids,
-                stream_ids_for_device_change,
+                stream_ids,
             )
 
             self._add_device_outbound_room_poke_txn(
@@ -1621,43 +1615,17 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 user_id,
                 device_ids,
                 room_ids,
-                stream_ids_for_device_change,
-                context,
-                hosts_have_been_calculated=hosts is not None,
-            )
-
-            # If the set of hosts to send to has not been calculated yet (and so
-            # `hosts` is None) or there are no `hosts` to send to, then skip
-            # trying to persist them to the DB.
-            if not hosts:
-                return
-
-            self._add_device_outbound_poke_to_stream_txn(
-                txn,
-                user_id,
-                device_ids,
-                hosts,
-                stream_ids_for_outbound_pokes,
+                stream_ids,
                 context,
             )
 
-        # `device_lists_stream` wants a stream ID per device update.
-        num_stream_ids = len(device_ids)
-
-        if hosts:
-            # `device_lists_outbound_pokes` wants a different stream ID for
-            # each row, which is a row per host per device update.
-            num_stream_ids += len(hosts) * len(device_ids)
-
-        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
-            stream_ids_for_device_change = stream_ids[: len(device_ids)]
-            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
-
+        async with self._device_list_id_gen.get_next_mult(
+            len(device_ids)
+        ) as stream_ids:
             await self.db_pool.runInteraction(
                 "add_device_change_to_stream",
                 add_device_changes_txn,
-                stream_ids_for_device_change,
-                stream_ids_for_outbound_pokes,
+                stream_ids,
             )
 
         return stream_ids[-1]
@@ -1735,7 +1703,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     next(stream_id_iterator),
                     user_id,
                     device_id,
-                    False,
+                    not self.hs.is_mine_id(
+                        user_id
+                    ),  # We only need to send out update for *our* users
                     now,
                     encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
@@ -1752,19 +1722,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         room_ids: Collection[str],
         stream_ids: List[str],
         context: Dict[str, str],
-        hosts_have_been_calculated: bool,
     ) -> None:
-        """Record the user in the room has updated their device.
-
-        Args:
-            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
-                has been updated already with the updates.
-        """
-
-        # We only need to convert to outbound pokes if they are our user.
-        converted_to_destinations = (
-            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
-        )
+        """Record the user in the room has updated their device."""
 
         encoded_context = json_encoder.encode(context)
 
@@ -1789,7 +1748,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     device_id,
                     room_id,
                     stream_id,
-                    converted_to_destinations,
+                    False,
                     encoded_context,
                 )
                 for room_id in room_ids
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3fcd5f5b99..2a1e567ce0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -963,6 +963,21 @@ class PersistEventsStore:
                 values=to_insert,
             )
 
+    async def update_current_state(
+        self,
+        room_id: str,
+        state_delta: DeltaState,
+        stream_id: int,
+    ) -> None:
+        """Update the current state stored in the datatabase for the given room"""
+
+        await self.db_pool.runInteraction(
+            "update_current_state",
+            self._update_current_state_txn,
+            state_delta_by_room={room_id: state_delta},
+            stream_id=stream_id,
+        )
+
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
@@ -1819,10 +1834,7 @@ class PersistEventsStore:
         if rel_type == RelationTypes.REPLACE:
             txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
-        if (
-            rel_type == RelationTypes.THREAD
-            or rel_type == RelationTypes.UNSTABLE_THREAD
-        ):
+        if rel_type == RelationTypes.THREAD:
             txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a60e3f4fdd..5288cdba03 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1979,3 +1979,27 @@ class EventsWorkerStore(SQLBaseStore):
             desc="is_partial_state_event",
         )
         return result is not None
+
+    async def get_partial_state_events_batch(self, room_id: str) -> List[str]:
+        """Get a list of events in the given room that have partial state"""
+        return await self.db_pool.runInteraction(
+            "get_partial_state_events_batch",
+            self._get_partial_state_events_batch_txn,
+            room_id,
+        )
+
+    @staticmethod
+    def _get_partial_state_events_batch_txn(
+        txn: LoggingTransaction, room_id: str
+    ) -> List[str]:
+        txn.execute(
+            """
+            SELECT event_id FROM partial_state_events AS pse
+                JOIN events USING (event_id)
+            WHERE pse.room_id = ?
+            ORDER BY events.stream_ordering
+            LIMIT 100
+            """,
+            (room_id,),
+        )
+        return [row[0] for row in txn]
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 407158ceee..a5c31f6787 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,7 +14,6 @@
 
 import logging
 from typing import (
-    TYPE_CHECKING,
     Collection,
     Dict,
     FrozenSet,
@@ -32,20 +31,12 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import (
-    DatabasePool,
-    LoggingDatabaseConnection,
-    LoggingTransaction,
-    make_in_list_sql_clause,
-)
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
 logger = logging.getLogger(__name__)
 
 
@@ -63,16 +54,6 @@ class _RelatedEvent:
 
 
 class RelationsWorkerStore(SQLBaseStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
-
     @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
@@ -497,7 +478,7 @@ class RelationsWorkerStore(SQLBaseStore):
                         AND parent.room_id = child.room_id
                     WHERE
                         %s
-                        AND %s
+                        AND relation_type = ?
                     ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
                 """
             else:
@@ -512,22 +493,16 @@ class RelationsWorkerStore(SQLBaseStore):
                         AND parent.room_id = child.room_id
                     WHERE
                         %s
-                        AND %s
+                        AND relation_type = ?
                     ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
                 """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", event_ids
             )
+            args.append(RelationTypes.THREAD)
 
-            if self._msc3440_enabled:
-                relations_clause = "(relation_type = ? OR relation_type = ?)"
-                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
-            else:
-                relations_clause = "relation_type = ?"
-                args.append(RelationTypes.THREAD)
-
-            txn.execute(sql % (clause, relations_clause), args)
+            txn.execute(sql % (clause,), args)
             latest_event_ids = {}
             for parent_event_id, child_event_id in txn:
                 # Only consider the latest threaded reply (by topological ordering).
@@ -547,7 +522,7 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND parent.room_id = child.room_id
                 WHERE
                     %s
-                    AND %s
+                    AND relation_type = ?
                 GROUP BY parent.event_id
             """
 
@@ -556,15 +531,9 @@ class RelationsWorkerStore(SQLBaseStore):
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", latest_event_ids.keys()
             )
+            args.append(RelationTypes.THREAD)
 
-            if self._msc3440_enabled:
-                relations_clause = "(relation_type = ? OR relation_type = ?)"
-                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
-            else:
-                relations_clause = "relation_type = ?"
-                args.append(RelationTypes.THREAD)
-
-            txn.execute(sql % (clause, relations_clause), args)
+            txn.execute(sql % (clause,), args)
             counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
 
             return counts, latest_event_ids
@@ -622,7 +591,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 parent.event_id = relates_to_id
                 AND parent.room_id = child.room_id
             WHERE
-                %s
+                relation_type = ?
                 AND %s
                 AND %s
             GROUP BY parent.event_id, child.sender
@@ -638,16 +607,9 @@ class RelationsWorkerStore(SQLBaseStore):
                 txn.database_engine, "relates_to_id", event_ids
             )
 
-            if self._msc3440_enabled:
-                relations_clause = "(relation_type = ? OR relation_type = ?)"
-                relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
-            else:
-                relations_clause = "relation_type = ?"
-                relations_args = [RelationTypes.THREAD]
-
             txn.execute(
-                sql % (users_sql, events_clause, relations_clause),
-                users_args + events_args + relations_args,
+                sql % (users_sql, events_clause),
+                [RelationTypes.THREAD] + users_args + events_args,
             )
             return {(row[0], row[1]): row[2] for row in txn}
 
@@ -677,7 +639,7 @@ class RelationsWorkerStore(SQLBaseStore):
             user participated in that event's thread, otherwise false.
         """
 
-        def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
+        def _get_threads_participated_txn(txn: LoggingTransaction) -> Set[str]:
             # Fetch whether the requester has participated or not.
             sql = """
                 SELECT DISTINCT relates_to_id
@@ -688,28 +650,20 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND parent.room_id = child.room_id
                 WHERE
                     %s
-                    AND %s
+                    AND relation_type = ?
                     AND child.sender = ?
             """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", event_ids
             )
+            args.extend([RelationTypes.THREAD, user_id])
 
-            if self._msc3440_enabled:
-                relations_clause = "(relation_type = ? OR relation_type = ?)"
-                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
-            else:
-                relations_clause = "relation_type = ?"
-                args.append(RelationTypes.THREAD)
-
-            args.append(user_id)
-
-            txn.execute(sql % (clause, relations_clause), args)
+            txn.execute(sql % (clause,), args)
             return {row[0] for row in txn.fetchall()}
 
         participated_threads = await self.db_pool.runInteraction(
-            "get_thread_summary", _get_thread_summary_txn
+            "get_threads_participated", _get_threads_participated_txn
         )
 
         return {event_id: event_id in participated_threads for event_id in event_ids}
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 18b1acd9e1..87e9482c60 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1077,6 +1077,37 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_rooms_for_retention_period_in_range_txn,
         )
 
+    async def clear_partial_state_room(self, room_id: str) -> bool:
+        # this can race with incoming events, so we watch out for FK errors.
+        # TODO(faster_joins): this still doesn't completely fix the race, since the persist process
+        #   is not atomic. I fear we need an application-level lock.
+        try:
+            await self.db_pool.runInteraction(
+                "clear_partial_state_room", self._clear_partial_state_room_txn, room_id
+            )
+            return True
+        except self.db_pool.engine.module.DatabaseError as e:
+            # TODO(faster_joins): how do we distinguish between FK errors and other errors?
+            logger.warning(
+                "Exception while clearing lazy partial-state-room %s, retrying: %s",
+                room_id,
+                e,
+            )
+            return False
+
+    @staticmethod
+    def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None:
+        DatabasePool.simple_delete_txn(
+            txn,
+            table="partial_state_rooms_servers",
+            keyvalues={"room_id": room_id},
+        )
+        DatabasePool.simple_delete_one_txn(
+            txn,
+            table="partial_state_rooms",
+            keyvalues={"room_id": room_id},
+        )
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ecdc1fdc4c..7a1b013fa3 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -129,7 +130,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         if room_version is None:
-            raise NotFoundError("Could not room_version for %s" % (room_id,))
+            raise NotFoundError("Could not find room_version for %s" % (room_id,))
 
         return room_version
 
@@ -354,6 +355,53 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return {row["state_group"] for row in rows}
 
+    async def update_state_for_partial_state_event(
+        self,
+        event: EventBase,
+        context: EventContext,
+    ) -> None:
+        """Update the state group for a partial state event"""
+        await self.db_pool.runInteraction(
+            "update_state_for_partial_state_event",
+            self._update_state_for_partial_state_event_txn,
+            event,
+            context,
+        )
+
+    def _update_state_for_partial_state_event_txn(
+        self,
+        txn,
+        event: EventBase,
+        context: EventContext,
+    ):
+        # we shouldn't have any outliers here
+        assert not event.internal_metadata.is_outlier()
+
+        # anything that was rejected should have the same state as its
+        # predecessor.
+        if context.rejected:
+            assert context.state_group == context.state_group_before_event
+
+        self.db_pool.simple_update_txn(
+            txn,
+            table="event_to_state_groups",
+            keyvalues={"event_id": event.event_id},
+            updatevalues={"state_group": context.state_group},
+        )
+
+        self.db_pool.simple_delete_one_txn(
+            txn,
+            table="partial_state_events",
+            keyvalues={"event_id": event.event_id},
+        )
+
+        # TODO(faster_joins): need to do something about workers here
+        txn.call_after(
+            self._get_state_group_for_event.prefill,
+            (event.event_id,),
+            context.state_group,
+        )
+
 
 class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 6d45a8a9f6..793e906630 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -758,6 +758,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "get_room_event_before_stream_ordering", _f
         )
 
+    async def get_last_event_in_room_before_stream_ordering(
+        self,
+        room_id: str,
+        end_token: RoomStreamToken,
+    ) -> Optional[EventBase]:
+        """Returns the last event in a room at or before a stream ordering
+
+        Args:
+            room_id
+            end_token: The token used to stream from
+
+        Returns:
+            The most recent event.
+        """
+
+        last_row = await self.get_room_event_before_stream_ordering(
+            room_id=room_id,
+            stream_ordering=end_token.stream,
+        )
+        if last_row:
+            _, _, event_id = last_row
+            event = await self.get_event(event_id, get_prev_content=True)
+            return event
+
+        return None
+
     async def get_current_room_stream_token_for_room_id(
         self, room_id: Optional[str] = None
     ) -> RoomStreamToken: