summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py3
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/event_push_actions.py16
-rw-r--r--synapse/storage/databases/main/stream.py20
4 files changed, 37 insertions, 4 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 9c06c837dc..f3f9c6d54c 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -270,6 +270,9 @@ class MockHomeserver:
     def get_instance_name(self) -> str:
         return "master"
 
+    def should_send_federation(self) -> bool:
+        return False
+
 
 class Porter:
     def __init__(
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 57aaf778ec..a3d31d3737 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -87,7 +87,6 @@ class DataStore(
     RoomStore,
     RoomBatchStore,
     RegistrationStore,
-    StreamWorkerStore,
     ProfileStore,
     PresenceStore,
     TransactionWorkerStore,
@@ -112,6 +111,7 @@ class DataStore(
     SearchStore,
     TagsStore,
     AccountDataStore,
+    StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,
     DeviceStore,
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 80ca2fd0b6..eae41d7484 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -25,8 +25,8 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -122,7 +122,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore):
+class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -218,7 +218,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
                 retcol="event_id",
             )
 
-            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)  # type: ignore[attr-defined]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
@@ -307,12 +307,22 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
         actions that have been deleted from `event_push_actions` table.
         """
 
+        # If there have been no events in the room since the stream ordering,
+        # there can't be any push actions either.
+        if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
+            return 0, 0
+
         clause = ""
         args = [user_id, room_id, stream_ordering]
         if max_stream_ordering is not None:
             clause = "AND ea.stream_ordering <= ?"
             args.append(max_stream_ordering)
 
+            # If the max stream ordering is less than the min stream ordering,
+            # then obviously there are zero push actions in that range.
+            if max_stream_ordering <= stream_ordering:
+                return 0, 0
+
         sql = f"""
             SELECT
                COUNT(CASE WHEN notif = 1 THEN 1 END),
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8e88784d3c..3a1df7776c 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -46,10 +46,12 @@ from typing import (
     Set,
     Tuple,
     cast,
+    overload,
 )
 
 import attr
 from frozendict import frozendict
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -795,6 +797,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return RoomStreamToken(topo, stream_ordering)
 
+    @overload
+    def get_stream_id_for_event_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        allow_none: Literal[False] = False,
+    ) -> int:
+        ...
+
+    @overload
+    def get_stream_id_for_event_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        ...
+
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,