diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a9f25a5904..0ce3156c9c 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
+ self._attempt_to_invalidate_cache("get_threads", (room_id,))
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 060fe71454..6698cbf664 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr
from prometheus_client import Counter
import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
@@ -1616,7 +1616,7 @@ class PersistEventsStore:
)
# Remove from relations table.
- self._handle_redact_relations(txn, event.redacts)
+ self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1866,6 +1866,34 @@ class PersistEventsStore:
},
)
+ if relation.rel_type == RelationTypes.THREAD:
+ # Upsert into the threads table, but only overwrite the value if the
+ # new event is of a later topological order OR if the topological
+ # ordering is equal, but the stream ordering is later.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (room_id, thread_id)
+ DO UPDATE SET
+ latest_event_id = excluded.latest_event_id,
+ topological_ordering = excluded.topological_ordering,
+ stream_ordering = excluded.stream_ordering
+ WHERE
+ threads.topological_ordering <= excluded.topological_ordering AND
+ threads.stream_ordering < excluded.stream_ordering
+ """
+
+ txn.execute(
+ sql,
+ (
+ event.room_id,
+ relation.parent_id,
+ event.event_id,
+ event.depth,
+ event.internal_metadata.stream_ordering,
+ ),
+ )
+
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
) -> None:
@@ -1989,13 +2017,14 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
def _handle_redact_relations(
- self, txn: LoggingTransaction, redacted_event_id: str
+ self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.
Args:
txn
+ room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""
@@ -2024,6 +2053,9 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_threads, (room_id,)
+ )
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e7fbf950e6..ac9b96ab44 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@
import logging
from typing import (
+ TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
@@ -29,18 +30,47 @@ from typing import (
import attr
from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+ topological_ordering: int
+ stream_ordering: int
+
+ def __str__(self) -> str:
+ return f"{self.topological_ordering}_{self.stream_ordering}"
+
+ @classmethod
+ def from_string(cls, string: str) -> "ThreadsNextBatch":
+ """
+ Creates a ThreadsNextBatch from its textual representation.
+ """
+ try:
+ keys = (int(s) for s in string.split("_"))
+ return cls(*keys)
+ except Exception:
+ raise SynapseError(400, "Invalid threads token")
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
Contains enough information about a related event in order to properly filter
@@ -56,6 +86,76 @@ class _RelatedEvent:
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_update_handler(
+ "threads_backfill", self._backfill_threads
+ )
+
+ async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+ """Backfill the threads table."""
+
+ def threads_backfill_txn(txn: LoggingTransaction) -> int:
+ last_thread_id = progress.get("last_thread_id", "")
+
+ # Get the latest event in each thread by topo ordering / stream ordering.
+ #
+ # Note that the MAX(event_id) is needed to abide by the rules of group by,
+ # but doesn't actually do anything since there should only be a single event
+ # ID per topo/stream ordering pair.
+ sql = f"""
+ SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id > ? AND
+ relation_type = '{RelationTypes.THREAD}'
+ GROUP BY room_id, relates_to_id
+ ORDER BY relates_to_id
+ LIMIT ?
+ """
+ txn.execute(sql, (last_thread_id, batch_size))
+
+ # No more rows to process.
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # Insert the rows into the threads table. If a matching thread already exists,
+ # assume it is from a newer event.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+ VALUES %s
+ ON CONFLICT (room_id, thread_id)
+ DO NOTHING
+ """
+ if isinstance(txn.database_engine, PostgresEngine):
+ txn.execute_values(sql % ("?",), rows, fetch=False)
+ else:
+ txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows)
+
+ # Mark the progress.
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+ )
+
+ return txn.rowcount
+
+ result = await self.db_pool.runInteraction(
+ "threads_backfill", threads_backfill_txn
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("threads_backfill")
+
+ return result
+
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
@@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ @cached(tree=True)
+ async def get_threads(
+ self,
+ room_id: str,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ """Get a list of thread IDs, ordered by topological ordering of their
+ latest reply.
+
+ Args:
+ room_id: The room the event belongs to.
+ limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from a previous next_batch, or from the start if None.
+
+ Returns:
+ A tuple of:
+ A list of thread root event IDs.
+
+ The next_batch, if one exists.
+ """
+ # Generate the pagination clause, if necessary.
+ #
+ # Find any threads where the latest reply is equal / before the last
+ # thread's topo ordering and earlier in stream ordering.
+ pagination_clause = ""
+ pagination_args: tuple = ()
+ if from_token:
+ pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+ pagination_args = (
+ from_token.topological_ordering,
+ from_token.stream_ordering,
+ )
+
+ sql = f"""
+ SELECT thread_id, topological_ordering, stream_ordering
+ FROM threads
+ WHERE
+ room_id = ?
+ {pagination_clause}
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT ?
+ """
+
+ def _get_threads_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ txn.execute(sql, (room_id, *pagination_args, limit + 1))
+
+ rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+ thread_ids = [r[0] for r in rows]
+
+ # If there are more events, generate the next pagination key from the
+ # last thread which will be returned.
+ next_token = None
+ if len(thread_ids) > limit:
+ last_topo_id = rows[-2][1]
+ last_stream_id = rows[-2][2]
+ next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
+
+ return thread_ids[:limit], next_token
+
+ return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
@cached()
async def get_thread_id(self, event_id: str) -> str:
"""
diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql
new file mode 100644
index 0000000000..aa7c5e9a2e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE threads (
+ room_id TEXT NOT NULL,
+ -- The event ID of the root event in the thread.
+ thread_id TEXT NOT NULL,
+ -- The latest event ID and corresponding topo / stream ordering.
+ latest_event_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL,
+ CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
+);
+
+CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering);
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7309, 'threads_backfill', '{}');
|