diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index ed15f88350..69310d9035 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,7 +14,6 @@
import abc
import logging
-import urllib
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
@@ -813,31 +812,27 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {}
- async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
+ async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None:
server_name = key_to_fetch_item.server_name
- key_ids = key_to_fetch_item.key_ids
try:
- keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
+ keys = await self.get_server_verify_keys_v2_direct(server_name)
results[server_name] = keys
except KeyLookupError as e:
- logger.warning(
- "Error looking up keys %s from %s: %s", key_ids, server_name, e
- )
+ logger.warning("Error looking up keys from %s: %s", server_name, e)
except Exception:
- logger.exception("Error getting keys %s from %s", key_ids, server_name)
+ logger.exception("Error getting keys from %s", server_name)
- await yieldable_gather_results(get_key, keys_to_fetch)
+ await yieldable_gather_results(get_keys, keys_to_fetch)
return results
- async def get_server_verify_key_v2_direct(
- self, server_name: str, key_ids: Iterable[str]
+ async def get_server_verify_keys_v2_direct(
+ self, server_name: str
) -> Dict[str, FetchKeyResult]:
"""
Args:
- server_name:
- key_ids:
+ server_name: Server to request keys from
Returns:
Map from key ID to lookup result
@@ -845,57 +840,41 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
Raises:
KeyLookupError if there was a problem making the lookup
"""
- keys: Dict[str, FetchKeyResult] = {}
-
- for requested_key_id in key_ids:
- # we may have found this key as a side-effect of asking for another.
- if requested_key_id in keys:
- continue
-
- time_now_ms = self.clock.time_msec()
- try:
- response = await self.client.get_json(
- destination=server_name,
- path="/_matrix/key/v2/server/"
- + urllib.parse.quote(requested_key_id, safe=""),
- ignore_backoff=True,
- # we only give the remote server 10s to respond. It should be an
- # easy request to handle, so if it doesn't reply within 10s, it's
- # probably not going to.
- #
- # Furthermore, when we are acting as a notary server, we cannot
- # wait all day for all of the origin servers, as the requesting
- # server will otherwise time out before we can respond.
- #
- # (Note that get_json may make 4 attempts, so this can still take
- # almost 45 seconds to fetch the headers, plus up to another 60s to
- # read the response).
- timeout=10000,
- )
- except (NotRetryingDestination, RequestSendFailed) as e:
- # these both have str() representations which we can't really improve
- # upon
- raise KeyLookupError(str(e))
- except HttpResponseException as e:
- raise KeyLookupError("Remote server returned an error: %s" % (e,))
-
- assert isinstance(response, dict)
- if response["server_name"] != server_name:
- raise KeyLookupError(
- "Expected a response for server %r not %r"
- % (server_name, response["server_name"])
- )
-
- response_keys = await self.process_v2_response(
- from_server=server_name,
- response_json=response,
- time_added_ms=time_now_ms,
+ time_now_ms = self.clock.time_msec()
+ try:
+ response = await self.client.get_json(
+ destination=server_name,
+ path="/_matrix/key/v2/server",
+ ignore_backoff=True,
+ # we only give the remote server 10s to respond. It should be an
+ # easy request to handle, so if it doesn't reply within 10s, it's
+ # probably not going to.
+ #
+ # Furthermore, when we are acting as a notary server, we cannot
+ # wait all day for all of the origin servers, as the requesting
+ # server will otherwise time out before we can respond.
+ #
+ # (Note that get_json may make 4 attempts, so this can still take
+ # almost 45 seconds to fetch the headers, plus up to another 60s to
+ # read the response).
+ timeout=10000,
)
- await self.store.store_server_verify_keys(
- server_name,
- time_now_ms,
- ((server_name, key_id, key) for key_id, key in response_keys.items()),
+ except (NotRetryingDestination, RequestSendFailed) as e:
+ # these both have str() representations which we can't really improve
+ # upon
+ raise KeyLookupError(str(e))
+ except HttpResponseException as e:
+ raise KeyLookupError("Remote server returned an error: %s" % (e,))
+
+ assert isinstance(response, dict)
+ if response["server_name"] != server_name:
+ raise KeyLookupError(
+ "Expected a response for server %r not %r"
+ % (server_name, response["server_name"])
)
- keys.update(response_keys)
- return keys
+ return await self.process_v2_response(
+ from_server=server_name,
+ response_json=response,
+ time_added_ms=time_now_ms,
+ )
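
Context for the change above: GET /_matrix/key/v2/server returns all of a
server's current verify keys in one response, which is why the per-key-ID
request path (and the urllib import used to quote the key ID) can be dropped.
A minimal standalone sketch, not part of this diff, assuming the server is
directly reachable over HTTPS and skipping the server-name resolution and
signature verification a real federation client performs:

    import json
    import urllib.request

    def fetch_server_keys(server_name: str) -> dict:
        # One request now fetches every current key; no key ID in the path.
        url = f"https://{server_name}/_matrix/key/v2/server"
        with urllib.request.urlopen(url, timeout=10) as resp:
            return json.load(resp)

    keys = fetch_server_keys("matrix.org")
    print(keys["server_name"], sorted(keys["verify_keys"]))
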
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index fc1d8c88a7..30ebd62883 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -647,7 +647,7 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
- domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
+ domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
room_id
)
domains = [
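
The renamed storage call above returns the hosts currently in the room, or an
approximation of them while the room only has partial state; the list
comprehension that starts at the end of this hunk then narrows that set before
the receipt is sent. A toy sketch of that narrowing step, with purely
illustrative names, assuming the local server is excluded from the poke list:

    def remote_domains(hosts_in_room: set, my_server_name: str) -> list:
        # Only remote servers need to be poked with the receipt.
        return sorted(d for d in hosts_in_room if d != my_server_name)

    assert remote_domains({"matrix.org", "example.com"}, "example.com") == ["matrix.org"]
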
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index edeba27a45..7ee07e4bee 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,7 +17,6 @@ from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
-from synapse.util.async_helpers import concurrently_execute
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites)
- room_notifs = []
-
- async def get_room_unread_count(room_id: str) -> None:
- room_notifs.append(
- await store.get_unread_event_push_actions_by_room_for_user(
- room_id,
- user_id,
- )
- )
-
- await concurrently_execute(get_room_unread_count, joins, 10)
-
- for notifs in room_notifs:
- # Combine the counts from all the threads.
- notify_count = notifs.main_timeline.notify_count + sum(
- n.notify_count for n in notifs.threads.values()
- )
+ room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
+ for room_id, notify_count in room_to_count.items():
+ # room_to_count may include rooms which the user has left;
+ # ignore those.
+ if room_id not in joins:
+ continue
if notify_count == 0:
continue
@@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
# return one badge count per conversation
badge += 1
else:
- # increment the badge count by the number of unread messages in the room
+ # Increase badge by number of notifications in room
+ # NOTE: this includes threaded and unthreaded notifications.
badge += notify_count
+
return badge
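
The badge arithmetic above, in a self-contained sketch (the helper name and
argument shapes are illustrative, not Synapse APIs): one badge per noisy room
when group_by_room is set, otherwise the summed notification counts, always
skipping rooms the user has left:

    from typing import Dict, Set

    def badge_from_counts(
        invites: int,
        joined_rooms: Set[str],
        room_to_count: Dict[str, int],
        group_by_room: bool,
    ) -> int:
        badge = invites
        for room_id, notify_count in room_to_count.items():
            # Skip rooms the user has left and rooms with nothing to notify.
            if room_id not in joined_rooms or notify_count == 0:
                continue
            badge += 1 if group_by_room else notify_count
        return badge

    # Two joined rooms with notifications; the left room is ignored.
    assert badge_from_counts(1, {"!a", "!b"}, {"!a": 5, "!b": 2, "!gone": 9}, True) == 3
    assert badge_from_counts(1, {"!a", "!b"}, {"!a": 5, "!b": 2, "!gone": 9}, False) == 8
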
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 3c0a90010b..e19c0946c0 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -77,6 +77,7 @@ class VersionsRestServlet(RestServlet):
"v1.2",
"v1.3",
"v1.4",
+ "v1.5",
],
# as per MSC1497:
"unstable_features": {
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b283ab0f9c..7ebe34f773 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,6 +74,7 @@ receipt.
"""
import logging
+from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -95,6 +96,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -463,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
+ async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
+ """Get the notification count by room for a user. Only considers notifications,
+ not highlight or unread counts, and threads are currently aggregated under their room.
+
+ This function is intentionally not cached because it is called to calculate the
+ unread badge for push notifications and thus the result is expected to change.
+
+ Note that this function assumes the user is a member of the room. Because
+ summary rows are not removed when a user leaves a room, the caller must
+ filter out any rooms the user has since left.
+
+ Returns:
+ A map of room ID to notification counts for the given user.
+ """
+ return await self.db_pool.runInteraction(
+ "get_unread_counts_by_room_for_user",
+ self._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+
+ def _get_unread_counts_by_room_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Dict[str, int]:
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ )
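+ # Each query below embeds the receipts CTE (which contains one "user_id = ?")
+ # and adds an outer "WHERE user_id = ?", hence the user ID appears twice here.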
+ args.extend([user_id, user_id])
+
+ receipts_cte = f"""
+ WITH all_receipts AS (
+ SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id, thread_id
+ )
+ """
+
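+ # Two views of the receipts: a threaded receipt joins on (room_id, thread_id),
+ # while an unthreaded receipt covers the whole room and joins on room_id alone.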
+ receipts_joins = """
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS threaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NOT NULL
+ ) AS threaded_receipts USING (room_id, thread_id)
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NULL
+ ) AS unthreaded_receipts USING (room_id)
+ """
+
+ # First get summary counts by room / thread for the user. We use the max receipt
+ # stream ordering of both threaded & unthreaded receipts to compare against the
+ # summary table.
+ #
+ # PostgreSQL and SQLite differ in how their scalar max functions handle NULLs.
+ if isinstance(self.database_engine, PostgresEngine):
+ # GREATEST ignores NULLs.
+ max_clause = """GREATEST(
+ threaded_receipt_stream_ordering,
+ unthreaded_receipt_stream_ordering
+ )"""
+ else:
+ # MAX returns NULL if any are NULL, so COALESCE to 0 first.
+ max_clause = """MAX(
+ COALESCE(threaded_receipt_stream_ordering, 0),
+ COALESCE(unthreaded_receipt_stream_ordering, 0)
+ )"""
+
+ sql = f"""
+ {receipts_cte}
+ SELECT eps.room_id, eps.thread_id, notif_count
+ FROM event_push_summary AS eps
+ {receipts_joins}
+ WHERE user_id = ?
+ AND notif_count != 0
+ AND (
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
+ OR last_receipt_stream_ordering = {max_clause}
+ )
+ """
+ txn.execute(sql, args)
+
+ seen_thread_ids = set()
+ room_to_count: Dict[str, int] = defaultdict(int)
+
+ for room_id, thread_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+ seen_thread_ids.add(thread_id)
+
+ # Now get any unrotated event push actions, using the same receipt joins and
+ # counting only actions newer than the rotation point and the matching receipts.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND epa.notif = 1
+ AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id, epa.thread_id
+ """
+ txn.execute(sql, args)
+
+ for room_id, thread_id, notif_count in txn:
+ # Note: only count actions in threads whose summary row matched above; threads without one are handled by the final query below.
+ if thread_id not in seen_thread_ids:
+ continue
+ room_to_count[room_id] += notif_count
+
+ thread_id_clause, thread_ids_args = make_in_list_sql_clause(
+ self.database_engine, "epa.thread_id", seen_thread_ids
+ )
+
+ # Finally, re-check event_push_actions for any threads not in the summary,
+ # ignoring the rotation position. This handles the case where a read receipt
+ # has arrived but has not yet been rotated, meaning the summary table is out
+ # of date; in that case we go back to the push actions table.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND NOT {thread_id_clause}
+ AND epa.notif = 1
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id
+ """
+
+ args.extend(thread_ids_args)
+ txn.execute(sql, args)
+
+ for room_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+
+ return room_to_count
+
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
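
A runnable aside (standard library only, not part of the diff) demonstrating
the NULL behaviour the max_clause branches above depend on: SQLite's scalar
MAX() returns NULL if any argument is NULL, whereas PostgreSQL's GREATEST
ignores NULLs, hence the COALESCE in the SQLite branch:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    # Scalar MAX() propagates NULL...
    assert conn.execute("SELECT MAX(NULL, 5)").fetchone() == (None,)
    # ...so the possibly-absent receipt orderings are COALESCEd to 0 first.
    assert conn.execute("SELECT MAX(COALESCE(NULL, 0), 5)").fetchone() == (5,)
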