diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8d94aeaa32..a75386f6a0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -236,7 +236,7 @@ class BulkPushRuleEvaluator:
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
- thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
+ thread_id = await self.store.get_thread_id(relation.parent_id)
# It's possible that old room versions have non-integer power levels (floats or
# strings). Workaround this by explicitly converting to int.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 6b7eec4bf2..e7fbf950e6 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -28,7 +28,7 @@ from typing import (
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
@@ -777,7 +777,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
- async def get_thread_id(self, event_id: str) -> Optional[str]:
+ async def get_thread_id(self, event_id: str) -> str:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
@@ -787,7 +787,7 @@ class RelationsWorkerStore(SQLBaseStore):
Returns:
The event ID of the root event in the thread, if this event is part
- of a thread. None, otherwise.
+ of a thread. "main", otherwise.
"""
# Since event relations form a tree, we should only ever find 0 or 1
# results from the below query.
@@ -802,13 +802,15 @@ class RelationsWorkerStore(SQLBaseStore):
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
"""
- def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
+ def _get_thread_id(txn: LoggingTransaction) -> str:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
- return None
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
|