diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 07ca16c94c..8f9cdbddbb 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -299,6 +299,19 @@ def _parse_oidc_config_dict(
config_path + ("client_secret",),
)
+ # If no client secret (or JWT key) is specified then the auth method must be "none"
+ client_auth_method = oidc_config.get("client_auth_method")
+ if client_secret is None and client_secret_jwt_key is None:
+ if client_auth_method is None:
+ client_auth_method = "none"
+ elif client_auth_method != "none":
+ raise ConfigError(
+ "No 'client_secret' is set in OIDC config, and 'client_auth_method' is not set to 'none'"
+ )
+
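+ # Otherwise default to client_secret_basic, i.e. sending the secret in the
+ # HTTP Basic Authorization header.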
+ if client_auth_method is None:
+ client_auth_method = "client_secret_basic"
+
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
@@ -309,7 +322,7 @@ def _parse_oidc_config_dict(
client_id=oidc_config["client_id"],
client_secret=client_secret,
client_secret_jwt_key=client_secret_jwt_key,
- client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
+ client_auth_method=client_auth_method,
pkce_method=oidc_config.get("pkce_method", "auto"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 882be905db..12837429b9 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -94,7 +94,7 @@ from synapse.types import (
)
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched
+from synapse.util.iterutils import batch_iter, partition, sorted_topologically
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -1141,16 +1141,8 @@ class FederationEventHandler:
partial_state_flags = await self._store.get_partial_state_events(seen)
partial_state = any(partial_state_flags.values())
- # Get the state of the events we know about
- ours = await self._state_storage_controller.get_state_groups_ids(
- room_id, seen, await_full_state=False
- )
-
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps: List[StateMap[str]] = list(ours.values())
-
- # we don't need this any more, let's delete it.
- del ours
+ state_maps: List[StateMap[str]] = []
# Ask the remote server for the states we don't
# know about
@@ -1169,6 +1161,17 @@ class FederationEventHandler:
state_maps.append(remote_state_map)
+ # Get the state of the events we know about. We do this *after*
+ # trying to fetch the missing state over federation, as that might
+ # fail, in which case we can skip loading the local state entirely.
+ ours = await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen, await_full_state=False
+ )
+ state_maps.extend(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
+
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
@@ -1678,57 +1681,36 @@ class FederationEventHandler:
# We need to persist an event's auth events before the event.
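+ # Map each event ID to the IDs of its auth events that are also in this batch.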
auth_graph = {
- ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map]
+ ev.event_id: [e_id for e_id in ev.auth_event_ids() if e_id in event_map]
for ev in event_map.values()
}
- for roots in sorted_topologically_batched(event_map.values(), auth_graph):
- if not roots:
- # if *none* of the remaining events are ready, that means
- # we have a loop. This either means a bug in our logic, or that
- # somebody has managed to create a loop (which requires finding a
- # hash collision in room v2 and later).
- logger.warning(
- "Loop found in auth events while fetching missing state/auth "
- "events: %s",
- shortstr(event_map.keys()),
- )
- return
-
- logger.info(
- "Persisting %i of %i remaining outliers: %s",
- len(roots),
- len(event_map),
- shortstr(e.event_id for e in roots),
- )
-
- await self._auth_and_persist_outliers_inner(room_id, roots)
-
- async def _auth_and_persist_outliers_inner(
- self, room_id: str, fetched_events: Collection[EventBase]
- ) -> None:
- """Helper for _auth_and_persist_outliers
-
- Persists a batch of events where we have (theoretically) already persisted all
- of their auth events.
-
- Marks the events as outliers, auths them, persists them to the database, and,
- where appropriate (eg, an invite), awakes the notifier.
+ sorted_auth_event_ids = sorted_topologically(event_map.keys(), auth_graph)
+ sorted_auth_events = [event_map[e_id] for e_id in sorted_auth_event_ids]
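+ # The topological sort ensures each event comes after the auth events it
+ # depends on, so persisting in this order persists auth events first.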
+ logger.info(
+ "Persisting %i remaining outliers: %s",
+ len(sorted_auth_events),
+ shortstr(e.event_id for e in sorted_auth_events),
+ )
- Params:
- origin: where the events came from
- room_id: the room that the events are meant to be in (though this has
- not yet been checked)
- fetched_events: the events to persist
- """
- # get all the auth events for all the events in this batch. By now, they should
- # have been persisted.
+ # Get all the auth events for all the events in this batch. They should either
+ # be in the batch itself or have already been persisted.
- auth_events = {
- aid for event in fetched_events for aid in event.auth_event_ids()
+ auth_event_ids = {
+ aid for event in sorted_auth_events for aid in event.auth_event_ids()
}
- persisted_events = await self._store.get_events(
- auth_events,
- allow_rejected=True,
- )
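+ # Seed the auth event map from the batch itself; anything still missing is
+ # fetched from the database below.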
+ auth_map = {
+ ev.event_id: ev
+ for ev in sorted_auth_events
+ if ev.event_id in auth_event_ids
+ }
+
+ missing_events = auth_event_ids.difference(auth_map)
+ if missing_events:
+ persisted_events = await self._store.get_events(
+ missing_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ )
+ auth_map.update(persisted_events)
events_and_contexts_to_persist: List[Tuple[EventBase, EventContext]] = []
@@ -1736,7 +1718,7 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
auth = []
for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id)
+ ae = auth_map.get(auth_event_id)
if not ae:
# the fact we can't find the auth event doesn't mean it doesn't
# exist, which means it is premature to reject `event`. Instead we
@@ -1755,7 +1737,9 @@ class FederationEventHandler:
context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(event)
- await check_state_independent_auth_rules(self._store, event)
+ await check_state_independent_auth_rules(
+ self._store, event, batched_auth_events=auth_map
+ )
check_state_dependent_auth_rules(event, auth)
except AuthError as e:
logger.warning("Rejecting %r because %s", event, e)
@@ -1772,7 +1756,7 @@ class FederationEventHandler:
events_and_contexts_to_persist.append((event, context))
- for event in fetched_events:
+ for event in sorted_auth_events:
await prep(event)
await self.persist_events_and_notify(
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index a534f5f280..78bcac1429 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -44,6 +44,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.config.ratelimiting import RatelimitSettings
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StrCollection
+from synapse.types.state import StateFilter
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -546,7 +547,16 @@ class RoomSummaryHandler:
Returns:
True if the room is accessible to the requesting user or server.
"""
- state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
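+ # Only fetch the state needed to check accessibility: the join rules, the
+ # history visibility, and (if there is a requester) their membership.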
+ event_types = [
+ (EventTypes.JoinRules, ""),
+ (EventTypes.RoomHistoryVisibility, ""),
+ ]
+ if requester:
+ event_types.append((EventTypes.Member, requester))
+
+ state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id, state_filter=StateFilter.from_types(event_types)
+ )
# If there's no state for the room, it isn't known.
if not state_ids:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0385c04bc2..2e10035772 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -583,10 +583,11 @@ class SyncHandler:
# `recents`, so partial state is only a problem when a membership
# event turns up in `recents` but has not made it into the current
# state.
- current_state_ids_map = (
- await self.store.get_partial_current_state_ids(room_id)
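+ # Rather than pulling out the full current state, just check which of the
+ # timeline's state events are (still) in the current state.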
+ current_state_ids = (
+ await self.store.check_if_events_in_current_state(
+ {e.event_id for e in recents if e.is_state()}
+ )
)
- current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
self._storage_controllers,
@@ -667,10 +668,11 @@ class SyncHandler:
# `loaded_recents`, so partial state is only a problem when a
# membership event turns up in `loaded_recents` but has not made it
# into the current state.
- current_state_ids_map = (
- await self.store.get_partial_current_state_ids(room_id)
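+ # As above: just check which of the timeline's state events are in the
+ # current state.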
+ current_state_ids = (
+ await self.store.check_if_events_in_current_state(
+ {e.event_id for e in loaded_recents if e.is_state()}
+ )
)
- current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
self._storage_controllers,
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 03ce0b4dc6..cce9583fa7 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -28,17 +28,11 @@ from synapse.storage.databases.main import DataStore
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
- joins = await store.get_rooms_for_user(user_id)
badge = len(invites)
room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
- for room_id, notify_count in room_to_count.items():
- # room_to_count may include rooms which the user has left,
- # ignore those.
- if room_id not in joins:
- continue
-
+ for _room_id, notify_count in room_to_count.items():
if notify_count == 0:
continue
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4bbcf7199c..5a1a3e8e65 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
* The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
"""
- (
- user_id_device_id_to_messages,
- last_processed_stream_id,
- ) = await self._get_device_messages(
- user_ids=[user_id],
- device_id=device_id,
- from_stream_id=from_stream_id,
- to_stream_id=to_stream_id,
- limit=limit,
- )
-
- if not user_id_device_id_to_messages:
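+ # Fast path: the in-memory stream change cache tells us cheaply whether this
+ # user may have any messages after `from_stream_id`, letting us skip the
+ # database query entirely when there is nothing new.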
+ if not self._device_inbox_stream_cache.has_entity_changed(
+ user_id, from_stream_id
+ ):
# There were no messages!
return [], to_stream_id
- # Extract the messages, no need to return the user and device ID again
- to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+ def get_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+ sql = """
+ SELECT stream_id, message_json FROM device_inbox
+ WHERE user_id = ? AND device_id = ?
+ AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit))
+
+ # Collect the messages for this device, tracking the last stream
+ # position processed.
+ last_processed_stream_pos = to_stream_id
+ to_device_messages: List[JsonDict] = []
+ rowcount = 0
+ for row in txn:
+ rowcount += 1
+
+ last_processed_stream_pos = row[0]
+ message_dict = db_to_json(row[1])
+
+ # Store the message
+ to_device_messages.append(message_dict)
- return to_device_messages, last_processed_stream_id
+ # start a new span for each message, so that we can tag each separately
+ with start_active_span("get_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+
+ if rowcount == limit:
+ # We ended up bumping up against the message limit. There may be more messages
+ # to retrieve. Return what we have, as well as the last stream position that
+ # was processed.
+ #
+ # The caller is expected to set this as the lower (exclusive) bound
+ # for the next query of this device.
+ return to_device_messages, last_processed_stream_pos
+
+ # The limit was not reached, thus we know that `to_device_messages`
+ # contains all to-device messages for the given device and stream id range.
+ #
+ # We return to_stream_id, which the caller should then provide as the lower
+ # (exclusive) bound on the next query of this device.
+ return to_device_messages, to_stream_id
+
+ return await self.db_pool.runInteraction(
+ "get_messages_for_device", get_device_messages_txn
+ )
async def _get_device_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
- device_id: Optional[str] = None,
- limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
@@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
- device_id: A device ID to query to-device messages for. If not provided, to-device
- messages from all device IDs for the given user IDs will be queried. May not be
- provided if `user_ids` contains more than one entry.
- limit: The maximum number of to-device messages to return. Can only be used when
- passing a single user ID / device ID tuple.
+
Returns:
A tuple containing:
@@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id
- # Prevent a query for one user's device also retrieving another user's device with
- # the same device ID (device IDs are not unique across users).
- if len(user_ids) > 1 and device_id is not None:
- raise AssertionError(
- "Programming error: 'device_id' cannot be supplied to "
- "_get_device_messages when >1 user_id has been provided"
- )
-
- # A limit can only be applied when querying for a single user ID / device ID tuple.
- # See the docstring of this function for more details.
- if limit is not None and device_id is None:
- raise AssertionError(
- "Programming error: _get_device_messages was passed 'limit' "
- "without a specific user_id/device_id"
- )
-
user_ids_to_query: Set[str] = set()
- device_ids_to_query: Set[str] = set()
-
- # Note that a device ID could be an empty str
- if device_id is not None:
- # If a device ID was passed, use it to filter results.
- # Otherwise, device IDs will be derived from the given collection of user IDs.
- device_ids_to_query.add(device_id)
# Determine which users have devices with pending messages
for user_id in user_ids:
@@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
- if not device_ids_to_query:
- user_device_dicts = cast(
- List[Tuple[str]],
- self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- column="user_id",
- iterable=user_ids_to_query,
- keyvalues={"hidden": False},
- retcols=("device_id",),
- ),
- )
- device_ids_to_query.update({row[0] for row in user_device_dicts})
+ user_device_dicts = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ ),
+ )
+
+ device_ids_to_query = {row[0] for row in user_device_dicts}
if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
to_stream_id,
)
- # If a limit was provided, limit the data retrieved from the database
- if limit is not None:
- sql += "LIMIT ?"
- sql_args += (limit,)
-
txn.execute(sql, sql_args)
# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
- last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
rowcount = 0
for row in txn:
rowcount += 1
- last_processed_stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
@@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
)
- if limit is not None and rowcount == limit:
- # We ended up bumping up against the message limit. There may be more messages
- # to retrieve. Return what we have, as well as the last stream position that
- # was processed.
- #
- # The caller is expected to set this as the lower (exclusive) bound
- # for the next query of this device.
- return recipient_device_to_messages, last_processed_stream_pos
-
- # The limit was not reached, thus we know that recipient_device_to_messages
- # contains all to-device messages for the given device and stream id range.
- #
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 38029710db..d3859014b6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1796,7 +1796,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_ids: The IDs of the devices to delete
"""
- def _delete_devices_txn(txn: LoggingTransaction) -> None:
+ def _delete_devices_txn(txn: LoggingTransaction, device_ids: List[str]) -> None:
self.db_pool.simple_delete_many_txn(
txn,
table="devices",
@@ -1813,7 +1813,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
- await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
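+ # Delete in batches of 100 so that deleting a very large number of devices
+ # doesn't happen in a single huge transaction.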
+ for batch in batch_iter(device_ids, 100):
+ await self.db_pool.runInteraction(
+ "delete_devices", _delete_devices_txn, batch
+ )
+
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 650b8c8135..6d4e2942ea 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -357,10 +357,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
This function is intentionally not cached because it is called to calculate the
unread badge for push notifications and thus the result is expected to change.
- Note that this function assumes the user is a member of the room. Because
- summary rows are not removed when a user leaves a room, the caller must
- filter out those results from the result.
-
Returns:
A map of room ID to notification counts for the given user.
"""
@@ -373,127 +369,170 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_unread_counts_by_room_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
- receipt_types_clause, args = make_in_list_sql_clause(
+ # To get the badge count of all rooms we need to make three queries:
+ # 1. Fetch all counts from `event_push_summary`, discarding any stale
+ # rooms.
+ # 2. Fetch all notifications from `event_push_actions` that haven't
+ # been rotated yet.
+ # 3. Fetch all notifications from `event_push_actions` for the stale
+ # rooms.
+ #
+ # The "stale room" scenario generally happens when there is a new read
+ # receipt that hasn't yet been processed to update the
+ # `event_push_summary` table. When that happens we ignore the
+ # `event_push_summary` table for that room and calculate the count
+ # manually from `event_push_actions`.
+
+ # We only need to take read receipts of these types into account.
+ receipt_types_clause, receipt_types_args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
- args.extend([user_id, user_id])
-
- receipts_cte = f"""
- WITH all_receipts AS (
- SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
- FROM receipts_linearized
- LEFT JOIN events USING (room_id, event_id)
- WHERE
- {receipt_types_clause}
- AND user_id = ?
- GROUP BY room_id, thread_id
- )
- """
-
- receipts_joins = """
- LEFT JOIN (
- SELECT room_id, thread_id,
- max_receipt_stream_ordering AS threaded_receipt_stream_ordering
- FROM all_receipts
- WHERE thread_id IS NOT NULL
- ) AS threaded_receipts USING (room_id, thread_id)
- LEFT JOIN (
- SELECT room_id, thread_id,
- max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
- FROM all_receipts
- WHERE thread_id IS NULL
- ) AS unthreaded_receipts USING (room_id)
- """
-
- # First get summary counts by room / thread for the user. We use the max receipt
- # stream ordering of both threaded & unthreaded receipts to compare against the
- # summary table.
- #
- # PostgreSQL and SQLite differ in comparing scalar numerics.
- if isinstance(self.database_engine, PostgresEngine):
- # GREATEST ignores NULLs.
- max_clause = """GREATEST(
- threaded_receipt_stream_ordering,
- unthreaded_receipt_stream_ordering
- )"""
- else:
- # MAX returns NULL if any are NULL, so COALESCE to 0 first.
- max_clause = """MAX(
- COALESCE(threaded_receipt_stream_ordering, 0),
- COALESCE(unthreaded_receipt_stream_ordering, 0)
- )"""
+ # Step 1, fetch all counts from `event_push_summary` for the user. This
+ # is slightly convoluted as we also need to pull out the stream ordering
+ # of the user's most recent receipt in the room (either a thread-aware
+ # or thread-unaware receipt) in order to determine whether the row in
+ # `event_push_summary` is stale. Hence the outer GROUP BY and the odd
+ # join condition against `receipts_linearized`.
sql = f"""
- {receipts_cte}
- SELECT eps.room_id, eps.thread_id, notif_count
- FROM event_push_summary AS eps
- {receipts_joins}
- WHERE user_id = ?
- AND notif_count != 0
- AND (
- (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
- OR last_receipt_stream_ordering = {max_clause}
+ SELECT room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering,
+ MAX(receipt_stream_ordering)
+ FROM (
+ SELECT e.room_id, notif_count, e.stream_ordering, e.thread_id, last_receipt_stream_ordering,
+ ev.stream_ordering AS receipt_stream_ordering
+ FROM event_push_summary AS e
+ INNER JOIN local_current_membership USING (user_id, room_id)
+ LEFT JOIN receipts_linearized AS r ON (
+ e.user_id = r.user_id
+ AND e.room_id = r.room_id
+ AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+ AND {receipt_types_clause}
)
+ LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+ WHERE e.user_id = ? AND notif_count > 0
+ ) AS es
+ GROUP BY room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering
"""
- txn.execute(sql, args)
-
- seen_thread_ids = set()
- room_to_count: Dict[str, int] = defaultdict(int)
- for room_id, thread_id, notif_count in txn:
- room_to_count[room_id] += notif_count
- seen_thread_ids.add(thread_id)
+ txn.execute(
+ sql,
+ receipt_types_args + [user_id],
+ )
- # Now get any event push actions that haven't been rotated using the same OR
- # join and filter by receipt and event push summary rotated up to stream ordering.
- sql = f"""
- {receipts_cte}
- SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
- FROM event_push_actions AS epa
- {receipts_joins}
- WHERE user_id = ?
- AND epa.notif = 1
- AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
- AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
- AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
- GROUP BY epa.room_id, epa.thread_id
- """
- txn.execute(sql, args)
+ room_to_count: Dict[str, int] = defaultdict(int)
+ stale_room_ids = set()
+ for row in txn:
+ room_id = row[0]
+ notif_count = row[1]
+ stream_ordering = row[2]
+ _thread_id = row[3]
+ last_receipt_stream_ordering = row[4]
+ receipt_stream_ordering = row[5]
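+
+ # `receipt_stream_ordering` is the stream ordering of the user's most
+ # recent matching receipt in the room (NULL if they have none).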
+
+ if last_receipt_stream_ordering is None:
+ if receipt_stream_ordering is None:
+ room_to_count[room_id] += notif_count
+ elif stream_ordering > receipt_stream_ordering:
+ room_to_count[room_id] += notif_count
+ else:
+ # The latest read receipt from the user is after all the rows for
+ # this room in `event_push_summary`. We ignore them, and
+ # calculate the count from `event_push_actions` in step 3.
+ pass
+ elif last_receipt_stream_ordering == receipt_stream_ordering:
+ room_to_count[room_id] += notif_count
+ else:
+ # The row is stale if `last_receipt_stream_ordering` is set and
+ # *doesn't* match the latest receipt from the user.
+ stale_room_ids.add(room_id)
- for room_id, thread_id, notif_count in txn:
- # Note: only count push actions we have valid summaries for with up to date receipt.
- if thread_id not in seen_thread_ids:
- continue
- room_to_count[room_id] += notif_count
+ # Discard any stale rooms from `room_to_count`, as we will recalculate
+ # them in step 3.
+ for room_id in stale_room_ids:
+ room_to_count.pop(room_id, None)
- thread_id_clause, thread_ids_args = make_in_list_sql_clause(
- self.database_engine, "epa.thread_id", seen_thread_ids
+ # Step 2, basically the same query, except against `event_push_actions`
+ # and only fetching rows inserted since the last rotation.
+ rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
)
- # Finally re-check event_push_actions for any rooms not in the summary, ignoring
- # the rotated up-to position. This handles the case where a read receipt has arrived
- # but not been rotated meaning the summary table is out of date, so we go back to
- # the push actions table.
sql = f"""
- {receipts_cte}
- SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
- FROM event_push_actions AS epa
- {receipts_joins}
- WHERE user_id = ?
- AND NOT {thread_id_clause}
- AND epa.notif = 1
- AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
- AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
- GROUP BY epa.room_id
+ SELECT room_id, thread_id
+ FROM (
+ SELECT e.room_id, e.stream_ordering, e.thread_id,
+ ev.stream_ordering AS receipt_stream_ordering
+ FROM event_push_actions AS e
+ INNER JOIN local_current_membership USING (user_id, room_id)
+ LEFT JOIN receipts_linearized AS r ON (
+ e.user_id = r.user_id
+ AND e.room_id = r.room_id
+ AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+ AND {receipt_types_clause}
+ )
+ LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+ WHERE e.user_id = ? AND notif > 0
+ AND e.stream_ordering > ?
+ ) AS es
+ GROUP BY room_id, stream_ordering, thread_id
+ HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0)
"""
- args.extend(thread_ids_args)
- txn.execute(sql, args)
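+ # The HAVING clause keeps only rows that come after the user's latest
+ # receipt, i.e. notifications that are still unread.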
+ txn.execute(
+ sql,
+ receipt_types_args + [user_id, rotated_upto_stream_ordering],
+ )
+ for room_id, _thread_id in txn:
+ # Again, we ignore any stale rooms.
+ if room_id not in stale_room_ids:
+ # For event push actions it is one notification per row.
+ room_to_count[room_id] += 1
+
+ # Step 3, if we have stale rooms then we need to recalculate the counts
+ # from `event_push_actions`. Again, this is basically the same query as
+ # above except without a lower bound on stream ordering and only against
+ # a specific set of rooms.
+ if stale_room_ids:
+ room_id_clause, room_id_args = make_in_list_sql_clause(
+ self.database_engine,
+ "e.room_id",
+ stale_room_ids,
+ )
- for room_id, notif_count in txn:
- room_to_count[room_id] += notif_count
+ sql = f"""
+ SELECT room_id, thread_id
+ FROM (
+ SELECT e.room_id, e.stream_ordering, e.thread_id,
+ ev.stream_ordering AS receipt_stream_ordering
+ FROM event_push_actions AS e
+ INNER JOIN local_current_membership USING (user_id, room_id)
+ LEFT JOIN receipts_linearized AS r ON (
+ e.user_id = r.user_id
+ AND e.room_id = r.room_id
+ AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+ AND {receipt_types_clause}
+ )
+ LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+ WHERE e.user_id = ? AND notif > 0
+ AND {room_id_clause}
+ ) AS es
+ GROUP BY room_id, stream_ordering, thread_id
+ HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0)
+ """
+ txn.execute(
+ sql,
+ receipt_types_args + [user_id] + room_id_args,
+ )
+ for room_id, _ in txn:
+ room_to_count[room_id] += 1
return room_to_count
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4700e74ad2..8006046453 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -24,13 +24,17 @@ from typing import (
Any,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
+ TypeVar,
+ Union,
cast,
+ overload,
)
import attr
@@ -52,7 +56,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict, JsonMapping, StateMap
+from synapse.types import JsonDict, JsonMapping, StateKey, StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -64,6 +68,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+_T = TypeVar("_T")
+
MAX_STATE_DELTA_HOPS = 100
@@ -318,6 +324,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_partial_current_state_ids", _get_current_state_ids_txn
)
+ async def check_if_events_in_current_state(
+ self, event_ids: StrCollection
+ ) -> FrozenSet[str]:
+ """Checks and returns which of the given events is part of the current state."""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("event_id",),
+ desc="check_if_events_in_current_state",
+ )
+
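+ # Each row is a single-element tuple; the trailing comma unpacks it.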
+ return frozenset(event_id for event_id, in rows)
+
# FIXME: how should this be cached?
@cancellable
async def get_partial_filtered_current_state_ids(
@@ -349,7 +369,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
- results = {}
+ results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
+
sql = """
SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@@ -726,3 +747,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
+
+
+@attr.s(auto_attribs=True, slots=True)
+class StateMapWrapper(Dict[StateKey, str]):
+ """A wrapper around a StateMap[str] to ensure that we only query for items
+ that were not filtered out.
+
+ This is to help prevent bugs where we filter out state but other bits of the
+ code expect the state to be there.
+ """
+
+ state_filter: StateFilter
+
+ def __getitem__(self, key: StateKey) -> str:
+ if key not in self.state_filter:
+ raise Exception(f"State map was filtered and doesn't include: {key}")
+ return super().__getitem__(key)
+
+ @overload
+ def get(self, key: Tuple[str, str]) -> Optional[str]:
+ ...
+
+ @overload
+ def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
+ ...
+
+ def get(
+ self, key: StateKey, default: Union[str, _T, None] = None
+ ) -> Union[str, _T, None]:
+ if key not in self.state_filter:
+ raise Exception(f"State map was filtered and doesn't include: {key}")
+ return super().get(key, default)
+
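+ # Membership checks must also go through the filter, otherwise callers
+ # could silently treat filtered-out state as absent.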
+ def __contains__(self, key: Any) -> bool:
+ if key not in self.state_filter:
+ raise Exception(f"State map was filtered and doesn't include: {key}")
+
+ return super().__contains__(key)
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 5ca3c94bce..53662372af 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -20,6 +20,7 @@
import logging
from typing import (
TYPE_CHECKING,
+ Any,
Callable,
Collection,
Dict,
@@ -584,6 +585,29 @@ class StateFilter:
# local users only
return False
+ def __contains__(self, key: Any) -> bool:
+ if not isinstance(key, tuple) or len(key) != 2:
+ raise TypeError(
+ f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}"
+ )
+
+ typ, state_key = key
+
+ if not isinstance(typ, str) or not isinstance(state_key, str):
+ raise TypeError(
+ f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})"
+ )
+
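+ # The key is "in" the filter if this filter would let state of that
+ # (type, state_key) pair through.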
+ if typ in self.types:
+ state_keys = self.types[typ]
+ if state_keys is None or state_key in state_keys:
+ return True
+
+ elif self.include_others:
+ return True
+
+ return False
+
_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|