| diff --git a/changelog.d/16756.misc b/changelog.d/16756.misc
new file mode 100644
index 0000000000..200e18fb7b
--- /dev/null
+++ b/changelog.d/16756.misc
@@ -0,0 +1 @@
+Improve DB performance of calculating badge counts for push.
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
 index 03ce0b4dc6..cce9583fa7 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -28,17 +28,11 @@ from synapse.storage.databases.main import DataStore
 
 async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
     invites = await store.get_invited_rooms_for_local_user(user_id)
-    joins = await store.get_rooms_for_user(user_id)
 
     badge = len(invites)
 
     room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
-    for room_id, notify_count in room_to_count.items():
-        # room_to_count may include rooms which the user has left,
-        # ignore those.
-        if room_id not in joins:
-            continue
-
+    for _room_id, notify_count in room_to_count.items():
         if notify_count == 0:
             continue
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
 index 650b8c8135..6d4e2942ea 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -357,10 +357,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         This function is intentionally not cached because it is called to calculate the
         unread badge for push notifications and thus the result is expected to change.
 
-        Note that this function assumes the user is a member of the room. Because
-        summary rows are not removed when a user leaves a room, the caller must
-        filter out those results from the result.
-
         Returns:
             A map of room ID to notification counts for the given user.
         """
@@ -373,127 +369,170 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
     def _get_unread_counts_by_room_for_user_txn(
         self, txn: LoggingTransaction, user_id: str
     ) -> Dict[str, int]:
-        receipt_types_clause, args = make_in_list_sql_clause(
+        # To get the badge count of all rooms we need to make three queries:
+        #   1. Fetch all counts from `event_push_summary`, discarding any stale
+        #      rooms.
+        #   2. Fetch all notifications from `event_push_actions` that haven't
+        #      been rotated yet.
+        #   3. Fetch all notifications from `event_push_actions` for the stale
+        #      rooms.
+        #
+        # The "stale room" scenario generally happens when there is a new read
+        # receipt that hasn't yet been processed to update the
+        # `event_push_summary` table. When that happens we ignore the
+        # `event_push_summary` table for that room and calculate the count
+        # manually from `event_push_actions`.
+
+        # We need to only take into account read receipts of these types.
+        receipt_types_clause, receipt_types_args = make_in_list_sql_clause(
             self.database_engine,
             "receipt_type",
             (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
-        args.extend([user_id, user_id])
-
-        receipts_cte = f"""
-            WITH all_receipts AS (
-                SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
-                FROM receipts_linearized
-                LEFT JOIN events USING (room_id, event_id)
-                WHERE
-                    {receipt_types_clause}
-                    AND user_id = ?
-                GROUP BY room_id, thread_id
-            )
-        """
-
-        receipts_joins = """
-            LEFT JOIN (
-                SELECT room_id, thread_id,
-                max_receipt_stream_ordering AS threaded_receipt_stream_ordering
-                FROM all_receipts
-                WHERE thread_id IS NOT NULL
-            ) AS threaded_receipts USING (room_id, thread_id)
-            LEFT JOIN (
-                SELECT room_id, thread_id,
-                max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
-                FROM all_receipts
-                WHERE thread_id IS NULL
-            ) AS unthreaded_receipts USING (room_id)
-        """
-
-        # First get summary counts by room / thread for the user. We use the max receipt
-        # stream ordering of both threaded & unthreaded receipts to compare against the
-        # summary table.
-        #
-        # PostgreSQL and SQLite differ in comparing scalar numerics.
-        if isinstance(self.database_engine, PostgresEngine):
-            # GREATEST ignores NULLs.
-            max_clause = """GREATEST(
-                threaded_receipt_stream_ordering,
-                unthreaded_receipt_stream_ordering
-            )"""
-        else:
-            # MAX returns NULL if any are NULL, so COALESCE to 0 first.
-            max_clause = """MAX(
-                COALESCE(threaded_receipt_stream_ordering, 0),
-                COALESCE(unthreaded_receipt_stream_ordering, 0)
-            )"""
 
+        # Step 1, fetch all counts from `event_push_summary` for the user. This
+        # is slightly convoluted as we also need to pull out the stream ordering
+        # of the most recent receipt of the user in the room (either a thread
+        # aware receipt or thread unaware receipt) in order to determine
+        # whether the row in `event_push_summary` is stale. Hence the outer
+        # GROUP BY and odd join condition against `receipts_linearized`.
         sql = f"""
-            {receipts_cte}
-            SELECT eps.room_id, eps.thread_id, notif_count
-            FROM event_push_summary AS eps
-            {receipts_joins}
-            WHERE user_id = ?
-                AND notif_count != 0
-                AND (
-                    (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
-                    OR last_receipt_stream_ordering = {max_clause}
+            SELECT room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering,
+                MAX(receipt_stream_ordering)
+            FROM (
+                SELECT e.room_id, notif_count, e.stream_ordering, e.thread_id, last_receipt_stream_ordering,
+                    ev.stream_ordering AS receipt_stream_ordering
+                FROM event_push_summary AS e
+                INNER JOIN local_current_membership USING (user_id, room_id)
+                LEFT JOIN receipts_linearized AS r ON (
+                    e.user_id = r.user_id
+                    AND e.room_id = r.room_id
+                    AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+                    AND {receipt_types_clause}
                 )
+                LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+                WHERE e.user_id = ? and notif_count > 0
+            ) AS es
+            GROUP BY room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering
         """
-        txn.execute(sql, args)
-
-        seen_thread_ids = set()
-        room_to_count: Dict[str, int] = defaultdict(int)
 
-        for room_id, thread_id, notif_count in txn:
-            room_to_count[room_id] += notif_count
-            seen_thread_ids.add(thread_id)
+        txn.execute(
+            sql,
+            receipt_types_args
+            + [
+                user_id,
+            ],
+        )
 
-        # Now get any event push actions that haven't been rotated using the same OR
-        # join and filter by receipt and event push summary rotated up to stream ordering.
-        sql = f"""
-            {receipts_cte}
-            SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
-            FROM event_push_actions AS epa
-            {receipts_joins}
-            WHERE user_id = ?
-                AND epa.notif = 1
-                AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
-                AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
-                AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
-            GROUP BY epa.room_id, epa.thread_id
-        """
-        txn.execute(sql, args)
+        room_to_count: Dict[str, int] = defaultdict(int)
+        stale_room_ids = set()
+        for row in txn:
+            room_id = row[0]
+            notif_count = row[1]
+            stream_ordering = row[2]
+            _thread_id = row[3]
+            last_receipt_stream_ordering = row[4]
+            receipt_stream_ordering = row[5]
+
+            if last_receipt_stream_ordering is None:
+                if receipt_stream_ordering is None:
+                    room_to_count[room_id] += notif_count
+                elif stream_ordering > receipt_stream_ordering:
+                    room_to_count[room_id] += notif_count
+                else:
+                    # The latest read receipt from the user is after all the rows for
+                    # this room in `event_push_summary`. We ignore them, and
+                    # calculate the count from `event_push_actions` in step 3.
+                    pass
+            elif last_receipt_stream_ordering == receipt_stream_ordering:
+                room_to_count[room_id] += notif_count
+            else:
+                # The row is stale if `last_receipt_stream_ordering` is set and
+                # *doesn't* match the latest receipt from the user.
+                stale_room_ids.add(room_id)
 
-        for room_id, thread_id, notif_count in txn:
-            # Note: only count push actions we have valid summaries for with up to date receipt.
-            if thread_id not in seen_thread_ids:
-                continue
-            room_to_count[room_id] += notif_count
+        # Discard any stale rooms from `room_to_count`, as we will recalculate
+        # them in step 3.
+        for room_id in stale_room_ids:
+            room_to_count.pop(room_id, None)
 
-        thread_id_clause, thread_ids_args = make_in_list_sql_clause(
-            self.database_engine, "epa.thread_id", seen_thread_ids
+        # Step 2, basically the same query, except against `event_push_actions`
+        # and only fetching rows inserted since the last rotation.
+        rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
         )
 
-        # Finally re-check event_push_actions for any rooms not in the summary, ignoring
-        # the rotated up-to position. This handles the case where a read receipt has arrived
-        # but not been rotated meaning the summary table is out of date, so we go back to
-        # the push actions table.
         sql = f"""
-            {receipts_cte}
-            SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
-            FROM event_push_actions AS epa
-            {receipts_joins}
-            WHERE user_id = ?
-            AND NOT {thread_id_clause}
-            AND epa.notif = 1
-            AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
-            AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
-            GROUP BY epa.room_id
+            SELECT room_id, thread_id
+            FROM (
+                SELECT e.room_id, e.stream_ordering, e.thread_id,
+                    ev.stream_ordering AS receipt_stream_ordering
+                FROM event_push_actions AS e
+                INNER JOIN local_current_membership USING (user_id, room_id)
+                LEFT JOIN receipts_linearized AS r ON (
+                    e.user_id = r.user_id
+                    AND e.room_id = r.room_id
+                    AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+                    AND {receipt_types_clause}
+                )
+                LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+                WHERE e.user_id = ? and notif > 0
+                    AND e.stream_ordering > ?
+            ) AS es
+            GROUP BY room_id, stream_ordering, thread_id
+            HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0)
         """
 
-        args.extend(thread_ids_args)
-        txn.execute(sql, args)
+        txn.execute(
+            sql,
+            receipt_types_args + [user_id, rotated_upto_stream_ordering],
+        )
+        for room_id, _thread_id in txn:
+            # Again, we ignore any stale rooms.
+            if room_id not in stale_room_ids:
+                # For event push actions it is one notification per row.
+                room_to_count[room_id] += 1
+
+        # Step 3, if we have stale rooms then we need to recalculate the counts
+        # from `event_push_actions`. Again, this is basically the same query as
+        # above except without a lower bound on stream ordering and only against
+        # a specific set of rooms.
+        if stale_room_ids:
+            room_id_clause, room_id_args = make_in_list_sql_clause(
+                self.database_engine,
+                "e.room_id",
+                stale_room_ids,
+            )
 
-        for room_id, notif_count in txn:
-            room_to_count[room_id] += notif_count
+            sql = f"""
+                SELECT room_id, thread_id
+                FROM (
+                    SELECT e.room_id, e.stream_ordering, e.thread_id,
+                        ev.stream_ordering AS receipt_stream_ordering
+                    FROM event_push_actions AS e
+                    INNER JOIN local_current_membership USING (user_id, room_id)
+                    LEFT JOIN receipts_linearized AS r ON (
+                        e.user_id = r.user_id
+                        AND e.room_id = r.room_id
+                        AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+                        AND {receipt_types_clause}
+                    )
+                    LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+                    WHERE e.user_id = ? and notif > 0
+                        AND {room_id_clause}
+                ) AS es
+                GROUP BY room_id, stream_ordering, thread_id
+                HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0)
+            """
+            txn.execute(
+                sql,
+                receipt_types_args + [user_id] + room_id_args,
+            )
+            for room_id, _ in txn:
+                room_to_count[room_id] += 1
 
         return room_to_count
 |