summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/crypto/keyring.py107
-rw-r--r--synapse/federation/sender/__init__.py2
-rw-r--r--synapse/push/push_tools.py28
-rw-r--r--synapse/rest/client/versions.py1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py149
5 files changed, 203 insertions, 84 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index ed15f88350..69310d9035 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,7 +14,6 @@
 
 import abc
 import logging
-import urllib
 from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
 
 import attr
@@ -813,31 +812,27 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
         results = {}
 
-        async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
+        async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None:
             server_name = key_to_fetch_item.server_name
-            key_ids = key_to_fetch_item.key_ids
 
             try:
-                keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
+                keys = await self.get_server_verify_keys_v2_direct(server_name)
                 results[server_name] = keys
             except KeyLookupError as e:
-                logger.warning(
-                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
-                )
+                logger.warning("Error looking up keys from %s: %s", server_name, e)
             except Exception:
-                logger.exception("Error getting keys %s from %s", key_ids, server_name)
+                logger.exception("Error getting keys from %s", server_name)
 
-        await yieldable_gather_results(get_key, keys_to_fetch)
+        await yieldable_gather_results(get_keys, keys_to_fetch)
         return results
 
-    async def get_server_verify_key_v2_direct(
-        self, server_name: str, key_ids: Iterable[str]
+    async def get_server_verify_keys_v2_direct(
+        self, server_name: str
     ) -> Dict[str, FetchKeyResult]:
         """
 
         Args:
-            server_name:
-            key_ids:
+            server_name: Server to request keys from
 
         Returns:
             Map from key ID to lookup result
@@ -845,57 +840,41 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         Raises:
             KeyLookupError if there was a problem making the lookup
         """
-        keys: Dict[str, FetchKeyResult] = {}
-
-        for requested_key_id in key_ids:
-            # we may have found this key as a side-effect of asking for another.
-            if requested_key_id in keys:
-                continue
-
-            time_now_ms = self.clock.time_msec()
-            try:
-                response = await self.client.get_json(
-                    destination=server_name,
-                    path="/_matrix/key/v2/server/"
-                    + urllib.parse.quote(requested_key_id, safe=""),
-                    ignore_backoff=True,
-                    # we only give the remote server 10s to respond. It should be an
-                    # easy request to handle, so if it doesn't reply within 10s, it's
-                    # probably not going to.
-                    #
-                    # Furthermore, when we are acting as a notary server, we cannot
-                    # wait all day for all of the origin servers, as the requesting
-                    # server will otherwise time out before we can respond.
-                    #
-                    # (Note that get_json may make 4 attempts, so this can still take
-                    # almost 45 seconds to fetch the headers, plus up to another 60s to
-                    # read the response).
-                    timeout=10000,
-                )
-            except (NotRetryingDestination, RequestSendFailed) as e:
-                # these both have str() representations which we can't really improve
-                # upon
-                raise KeyLookupError(str(e))
-            except HttpResponseException as e:
-                raise KeyLookupError("Remote server returned an error: %s" % (e,))
-
-            assert isinstance(response, dict)
-            if response["server_name"] != server_name:
-                raise KeyLookupError(
-                    "Expected a response for server %r not %r"
-                    % (server_name, response["server_name"])
-                )
-
-            response_keys = await self.process_v2_response(
-                from_server=server_name,
-                response_json=response,
-                time_added_ms=time_now_ms,
+        time_now_ms = self.clock.time_msec()
+        try:
+            response = await self.client.get_json(
+                destination=server_name,
+                path="/_matrix/key/v2/server",
+                ignore_backoff=True,
+                # we only give the remote server 10s to respond. It should be an
+                # easy request to handle, so if it doesn't reply within 10s, it's
+                # probably not going to.
+                #
+                # Furthermore, when we are acting as a notary server, we cannot
+                # wait all day for all of the origin servers, as the requesting
+                # server will otherwise time out before we can respond.
+                #
+                # (Note that get_json may make 4 attempts, so this can still take
+                # almost 45 seconds to fetch the headers, plus up to another 60s to
+                # read the response).
+                timeout=10000,
             )
-            await self.store.store_server_verify_keys(
-                server_name,
-                time_now_ms,
-                ((server_name, key_id, key) for key_id, key in response_keys.items()),
+        except (NotRetryingDestination, RequestSendFailed) as e:
+            # these both have str() representations which we can't really improve
+            # upon
+            raise KeyLookupError(str(e))
+        except HttpResponseException as e:
+            raise KeyLookupError("Remote server returned an error: %s" % (e,))
+
+        assert isinstance(response, dict)
+        if response["server_name"] != server_name:
+            raise KeyLookupError(
+                "Expected a response for server %r not %r"
+                % (server_name, response["server_name"])
             )
-            keys.update(response_keys)
 
-        return keys
+        return await self.process_v2_response(
+            from_server=server_name,
+            response_json=response,
+            time_added_ms=time_now_ms,
+        )
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index fc1d8c88a7..30ebd62883 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -647,7 +647,7 @@ class FederationSender(AbstractFederationSender):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
+        domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
             room_id
         )
         domains = [
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index edeba27a45..7ee07e4bee 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,7 +17,6 @@ from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
-from synapse.util.async_helpers import concurrently_execute
 
 
 async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
 
     badge = len(invites)
 
-    room_notifs = []
-
-    async def get_room_unread_count(room_id: str) -> None:
-        room_notifs.append(
-            await store.get_unread_event_push_actions_by_room_for_user(
-                room_id,
-                user_id,
-            )
-        )
-
-    await concurrently_execute(get_room_unread_count, joins, 10)
-
-    for notifs in room_notifs:
-        # Combine the counts from all the threads.
-        notify_count = notifs.main_timeline.notify_count + sum(
-            n.notify_count for n in notifs.threads.values()
-        )
+    room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
+    for room_id, notify_count in room_to_count.items():
+        # room_to_count may include rooms which the user has left,
+        # ignore those.
+        if room_id not in joins:
+            continue
 
         if notify_count == 0:
             continue
@@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
             # return one badge count per conversation
             badge += 1
         else:
-            # increment the badge count by the number of unread messages in the room
+            # Increase badge by number of notifications in room
+            # NOTE: this includes threaded and unthreaded notifications.
             badge += notify_count
+
     return badge
 
 
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 3c0a90010b..e19c0946c0 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -77,6 +77,7 @@ class VersionsRestServlet(RestServlet):
                     "v1.2",
                     "v1.3",
                     "v1.4",
+                    "v1.5",
                 ],
                 # as per MSC1497:
                 "unstable_features": {
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b283ab0f9c..7ebe34f773 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,6 +74,7 @@ receipt.
 """
 
 import logging
+from collections import defaultdict
 from typing import (
     TYPE_CHECKING,
     Collection,
@@ -95,6 +96,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    PostgresEngine,
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -463,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return result
 
+    async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
+        """Get the notification count by room for a user. Only considers notifications,
+        not highlight or unread counts, and threads are currently aggregated under their room.
+
+        This function is intentionally not cached because it is called to calculate the
+        unread badge for push notifications and thus the result is expected to change.
+
+        Note that this function assumes the user is a member of the room. Because
+        summary rows are not removed when a user leaves a room, the caller must
+        filter out those results from the result.
+
+        Returns:
+            A map of room ID to notification counts for the given user.
+        """
+        return await self.db_pool.runInteraction(
+            "get_unread_counts_by_room_for_user",
+            self._get_unread_counts_by_room_for_user_txn,
+            user_id,
+        )
+
+    def _get_unread_counts_by_room_for_user_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> Dict[str, int]:
+        receipt_types_clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "receipt_type",
+            (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+        )
+        args.extend([user_id, user_id])
+
+        receipts_cte = f"""
+            WITH all_receipts AS (
+                SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
+                FROM receipts_linearized
+                LEFT JOIN events USING (room_id, event_id)
+                WHERE
+                    {receipt_types_clause}
+                    AND user_id = ?
+                GROUP BY room_id, thread_id
+            )
+        """
+
+        receipts_joins = """
+            LEFT JOIN (
+                SELECT room_id, thread_id,
+                max_receipt_stream_ordering AS threaded_receipt_stream_ordering
+                FROM all_receipts
+                WHERE thread_id IS NOT NULL
+            ) AS threaded_receipts USING (room_id, thread_id)
+            LEFT JOIN (
+                SELECT room_id, thread_id,
+                max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
+                FROM all_receipts
+                WHERE thread_id IS NULL
+            ) AS unthreaded_receipts USING (room_id)
+        """
+
+        # First get summary counts by room / thread for the user. We use the max receipt
+        # stream ordering of both threaded & unthreaded receipts to compare against the
+        # summary table.
+        #
+        # PostgreSQL and SQLite differ in comparing scalar numerics.
+        if isinstance(self.database_engine, PostgresEngine):
+            # GREATEST ignores NULLs.
+            max_clause = """GREATEST(
+                threaded_receipt_stream_ordering,
+                unthreaded_receipt_stream_ordering
+            )"""
+        else:
+            # MAX returns NULL if any are NULL, so COALESCE to 0 first.
+            max_clause = """MAX(
+                COALESCE(threaded_receipt_stream_ordering, 0),
+                COALESCE(unthreaded_receipt_stream_ordering, 0)
+            )"""
+
+        sql = f"""
+            {receipts_cte}
+            SELECT eps.room_id, eps.thread_id, notif_count
+            FROM event_push_summary AS eps
+            {receipts_joins}
+            WHERE user_id = ?
+                AND notif_count != 0
+                AND (
+                    (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
+                    OR last_receipt_stream_ordering = {max_clause}
+                )
+        """
+        txn.execute(sql, args)
+
+        seen_thread_ids = set()
+        room_to_count: Dict[str, int] = defaultdict(int)
+
+        for room_id, thread_id, notif_count in txn:
+            room_to_count[room_id] += notif_count
+            seen_thread_ids.add(thread_id)
+
+        # Now get any event push actions that haven't been rotated using the same OR
+        # join and filter by receipt and event push summary rotated up to stream ordering.
+        sql = f"""
+            {receipts_cte}
+            SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+            FROM event_push_actions AS epa
+            {receipts_joins}
+            WHERE user_id = ?
+                AND epa.notif = 1
+                AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
+                AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+                AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+            GROUP BY epa.room_id, epa.thread_id
+        """
+        txn.execute(sql, args)
+
+        for room_id, thread_id, notif_count in txn:
+            # Note: only count push actions we have valid summaries for with up to date receipt.
+            if thread_id not in seen_thread_ids:
+                continue
+            room_to_count[room_id] += notif_count
+
+        thread_id_clause, thread_ids_args = make_in_list_sql_clause(
+            self.database_engine, "epa.thread_id", seen_thread_ids
+        )
+
+        # Finally re-check event_push_actions for any rooms not in the summary, ignoring
+        # the rotated up-to position. This handles the case where a read receipt has arrived
+        # but not been rotated meaning the summary table is out of date, so we go back to
+        # the push actions table.
+        sql = f"""
+            {receipts_cte}
+            SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+            FROM event_push_actions AS epa
+            {receipts_joins}
+            WHERE user_id = ?
+            AND NOT {thread_id_clause}
+            AND epa.notif = 1
+            AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+            AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+            GROUP BY epa.room_id
+        """
+
+        args.extend(thread_ids_args)
+        txn.execute(sql, args)
+
+        for room_id, notif_count in txn:
+            room_to_count[room_id] += notif_count
+
+        return room_to_count
+
     @cached(tree=True, max_entries=5000, iterable=True)
     async def get_unread_event_push_actions_by_room_for_user(
         self,