summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py85
1 files changed, 70 insertions, 15 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5db70f9a60..161aad0f89 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -80,7 +80,7 @@ import attr
 
 from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -259,7 +259,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             txn,
             user_id,
             room_id,
-            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+            receipt_types=(
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ),
         )
 
         stream_ordering = None
@@ -448,6 +452,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by ascending stream_ordering.
             The list will have between 0~limit entries.
         """
+
         # find rooms that have a read receipt in them and return the next
         # push actions
         def get_after_receipt(
@@ -455,7 +460,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         ) -> List[Tuple[str, str, int, str, bool]]:
             # find rooms that have a read receipt in them and return the next
             # push actions
-            sql = """
+
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight
                 FROM (
@@ -463,10 +479,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                         MAX(stream_ordering) as stream_ordering
                     FROM events
                     INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE receipt_type = 'm.read' AND user_id = ?
+                    WHERE {receipt_types_clause} AND user_id = ?
                     GROUP BY room_id
                 ) AS rl,
-                    event_push_actions AS ep
+                event_push_actions AS ep
                 WHERE
                     ep.room_id = rl.room_id
                     AND ep.stream_ordering > rl.stream_ordering
@@ -476,7 +492,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
+            )
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
@@ -490,7 +508,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         def get_no_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            sql = """
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight
                 FROM event_push_actions AS ep
@@ -498,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 WHERE
                     ep.room_id NOT IN (
                         SELECT room_id FROM receipts_linearized
-                        WHERE receipt_type = 'm.read' AND user_id = ?
+                        WHERE {receipt_types_clause} AND user_id = ?
                         GROUP BY room_id
                     )
                     AND ep.user_id = ?
@@ -507,7 +535,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
+            )
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
@@ -557,12 +587,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by descending received_ts.
             The list will have between 0~limit entries.
         """
+
         # find rooms that have a read receipt in them and return the most recent
         # push actions
         def get_after_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = """
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight, e.received_ts
                 FROM (
@@ -570,7 +611,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                         MAX(stream_ordering) as stream_ordering
                     FROM events
                     INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE receipt_type = 'm.read' AND user_id = ?
+                    WHERE {receipt_types_clause} AND user_id = ?
                     GROUP BY room_id
                 ) AS rl,
                 event_push_actions AS ep
@@ -584,7 +625,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
+            )
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
@@ -598,7 +641,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         def get_no_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = """
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight, e.received_ts
                 FROM event_push_actions AS ep
@@ -606,7 +659,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 WHERE
                     ep.room_id NOT IN (
                         SELECT room_id FROM receipts_linearized
-                        WHERE receipt_type = 'm.read' AND user_id = ?
+                        WHERE {receipt_types_clause} AND user_id = ?
                         GROUP BY room_id
                     )
                     AND ep.user_id = ?
@@ -615,7 +668,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
+            )
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())