diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 407158ceee..a5c31f6787 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,7 +14,6 @@
import logging
from typing import (
- TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
@@ -32,20 +31,12 @@ import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- make_in_list_sql_clause,
-)
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
logger = logging.getLogger(__name__)
@@ -63,16 +54,6 @@ class _RelatedEvent:
class RelationsWorkerStore(SQLBaseStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._msc3440_enabled = hs.config.experimental.msc3440_enabled
-
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
@@ -497,7 +478,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND %s
+ AND relation_type = ?
ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
"""
else:
@@ -512,22 +493,16 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND %s
+ AND relation_type = ?
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
+ args.append(RelationTypes.THREAD)
- if self._msc3440_enabled:
- relations_clause = "(relation_type = ? OR relation_type = ?)"
- args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
- else:
- relations_clause = "relation_type = ?"
- args.append(RelationTypes.THREAD)
-
- txn.execute(sql % (clause, relations_clause), args)
+ txn.execute(sql % (clause,), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
@@ -547,7 +522,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND %s
+ AND relation_type = ?
GROUP BY parent.event_id
"""
@@ -556,15 +531,9 @@ class RelationsWorkerStore(SQLBaseStore):
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)
+ args.append(RelationTypes.THREAD)
- if self._msc3440_enabled:
- relations_clause = "(relation_type = ? OR relation_type = ?)"
- args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
- else:
- relations_clause = "relation_type = ?"
- args.append(RelationTypes.THREAD)
-
- txn.execute(sql % (clause, relations_clause), args)
+ txn.execute(sql % (clause,), args)
counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
return counts, latest_event_ids
@@ -622,7 +591,7 @@ class RelationsWorkerStore(SQLBaseStore):
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
- %s
+ relation_type = ?
AND %s
AND %s
GROUP BY parent.event_id, child.sender
@@ -638,16 +607,9 @@ class RelationsWorkerStore(SQLBaseStore):
txn.database_engine, "relates_to_id", event_ids
)
- if self._msc3440_enabled:
- relations_clause = "(relation_type = ? OR relation_type = ?)"
- relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
- else:
- relations_clause = "relation_type = ?"
- relations_args = [RelationTypes.THREAD]
-
txn.execute(
- sql % (users_sql, events_clause, relations_clause),
- users_args + events_args + relations_args,
+ sql % (users_sql, events_clause),
+ [RelationTypes.THREAD] + users_args + events_args,
)
return {(row[0], row[1]): row[2] for row in txn}
@@ -677,7 +639,7 @@ class RelationsWorkerStore(SQLBaseStore):
user participated in that event's thread, otherwise false.
"""
- def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
+ def _get_threads_participated_txn(txn: LoggingTransaction) -> Set[str]:
# Fetch whether the requester has participated or not.
sql = """
SELECT DISTINCT relates_to_id
@@ -688,28 +650,20 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND %s
+ AND relation_type = ?
AND child.sender = ?
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
+ args.extend([RelationTypes.THREAD, user_id])
- if self._msc3440_enabled:
- relations_clause = "(relation_type = ? OR relation_type = ?)"
- args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
- else:
- relations_clause = "relation_type = ?"
- args.append(RelationTypes.THREAD)
-
- args.append(user_id)
-
- txn.execute(sql % (clause, relations_clause), args)
+ txn.execute(sql % (clause,), args)
return {row[0] for row in txn.fetchall()}
participated_threads = await self.db_pool.runInteraction(
- "get_thread_summary", _get_thread_summary_txn
+ "get_threads_participated", _get_threads_participated_txn
)
return {event_id: event_id in participated_threads for event_id in event_ids}
|