diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 73c95ffb6f..48a54d9cb8 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -26,8 +26,15 @@ from typing import (
cast,
)
+from synapse.api.constants import EventContentFields
from synapse.logging import issue9533_logger
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import (
+ SynapseTags,
+ log_kv,
+ set_tag,
+ start_active_span,
+ trace,
+)
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -397,6 +404,17 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(recipient_user_id, recipient_device_id), []
).append(message_dict)
+ # start a new span for each message, so that we can tag each separately
+ with start_active_span("get_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+
if limit is not None and rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
@@ -678,12 +696,35 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- if remote_messages_by_destination:
- issue9533_logger.debug(
- "Queued outgoing to-device messages with stream_id %i for %s",
- stream_id,
- list(remote_messages_by_destination.keys()),
- )
+ for destination, edu in remote_messages_by_destination.items():
+ if issue9533_logger.isEnabledFor(logging.DEBUG):
+ issue9533_logger.debug(
+ "Queued outgoing to-device messages with "
+ "stream_id %i, EDU message_id %s, type %s for %s: %s",
+ stream_id,
+ edu["message_id"],
+ edu["type"],
+ destination,
+ [
+ f"{user_id}/{device_id} (msgid "
+ f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})"
+ for (user_id, messages_by_device) in edu["messages"].items()
+ for (device_id, msg) in messages_by_device.items()
+ ],
+ )
+
+ for (user_id, messages_by_device) in edu["messages"].items():
+ for (device_id, msg) in messages_by_device.items():
+ with start_active_span("store_outgoing_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
+ set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
+ set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ msg.get(EventContentFields.TO_DEVICE_MSGID),
+ )
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self._clock.time_msec()
@@ -801,7 +842,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Only insert into the local inbox if the device exists on
# this server
device_id = row["device_id"]
- message_json = json_encoder.encode(messages_by_device[device_id])
+
+ with start_active_span("serialise_to_device_message"):
+ msg = messages_by_device[device_id]
+ set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ msg["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+ message_json = json_encoder.encode(msg)
+
messages_json_for_user[device_id] = message_json
if messages_json_for_user:
@@ -821,15 +874,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- issue9533_logger.debug(
- "Stored to-device messages with stream_id %i for %s",
- stream_id,
- [
- (user_id, device_id)
- for (user_id, messages_by_device) in local_by_user_then_device.items()
- for device_id in messages_by_device.keys()
- ],
- )
+ if issue9533_logger.isEnabledFor(logging.DEBUG):
+ issue9533_logger.debug(
+ "Stored to-device messages with stream_id %i: %s",
+ stream_id,
+ [
+ f"{user_id}/{device_id} (msgid "
+ f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})"
+ for (
+ user_id,
+ messages_by_device,
+ ) in messages_by_user_then_device.items()
+ for (device_id, msg) in messages_by_device.items()
+ ],
+ )
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 534f7fc04a..a5bb4d404e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+ AllEntitiesChangedResult,
+ StreamChangeCache,
+)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -799,7 +802,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_cached_device_list_changes(
self,
from_key: int,
- ) -> Optional[List[str]]:
+ ) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -807,10 +810,58 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
+ async def get_all_devices_changed(
+ self,
+ from_key: int,
+ to_key: int,
+ ) -> Set[str]:
+ """Get all users whose devices have changed in the given range.
+
+ Args:
+ from_key: The minimum device lists stream token to query device list
+ changes for, exclusive.
+ to_key: The maximum device lists stream token to query device list
+ changes for, inclusive.
+
+ Returns:
+ The set of user_ids whose devices have changed since `from_key`
+ (exclusive) until `to_key` (inclusive).
+ """
+
+ result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+ if result.hit:
+ # We know which users might have changed devices.
+ if not result.entities:
+ # If no users then we can return early.
+ return set()
+
+ # Otherwise we need to filter down the list
+ return await self.get_users_whose_devices_changed(
+ from_key, result.entities, to_key
+ )
+
+ # If the cache didn't tell us anything, we just need to query the full
+ # range.
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+
+ rows = await self.db_pool.execute(
+ "get_all_devices_changed",
+ None,
+ sql,
+ from_key,
+ to_key,
+ )
+ return {u for u, in rows}
+
+ @cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
- user_ids: Optional[Collection[str]] = None,
+ user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -830,46 +881,31 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- user_ids_to_check: Optional[Collection[str]]
- if user_ids is None:
- # Get set of all users that have had device list changes since 'from_key'
- user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
- from_key
- )
- else:
- # The same as above, but filter results to only those users in 'user_ids'
- user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
+ # If an empty set was returned, there's nothing to do.
if not user_ids_to_check:
return set()
- def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
- changes: Set[str] = set()
-
- stream_id_where_clause = "stream_id > ?"
- sql_args = [from_key]
-
- if to_key:
- stream_id_where_clause += " AND stream_id <= ?"
- sql_args.append(to_key)
+ if to_key is None:
+ to_key = self._device_list_id_gen.get_current_token()
- sql = f"""
+ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
+ sql = """
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE {stream_id_where_clause}
- AND
+ WHERE ? < stream_id AND stream_id <= ? AND %s
"""
+ changes: Set[str] = set()
+
# Query device changes with a batch of users at a time
- # Assertion for mypy's benefit; see also
- # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
- assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, sql_args + args)
+ txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)
return changes
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b283ab0f9c..7ebe34f773 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,6 +74,7 @@ receipt.
"""
import logging
+from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -95,6 +96,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -463,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
+ async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
+ """Get the notification count by room for a user. Only considers notifications,
+ not highlight or unread counts, and threads are currently aggregated under their room.
+
+ This function is intentionally not cached because it is called to calculate the
+ unread badge for push notifications and thus the result is expected to change.
+
+ Note that this function assumes the user is a member of the room. Because
+ summary rows are not removed when a user leaves a room, the caller must
+ filter out those results from the result.
+
+ Returns:
+ A map of room ID to notification counts for the given user.
+ """
+ return await self.db_pool.runInteraction(
+ "get_unread_counts_by_room_for_user",
+ self._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+
+ def _get_unread_counts_by_room_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Dict[str, int]:
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ )
+ args.extend([user_id, user_id])
+
+ receipts_cte = f"""
+ WITH all_receipts AS (
+ SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id, thread_id
+ )
+ """
+
+ receipts_joins = """
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS threaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NOT NULL
+ ) AS threaded_receipts USING (room_id, thread_id)
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NULL
+ ) AS unthreaded_receipts USING (room_id)
+ """
+
+ # First get summary counts by room / thread for the user. We use the max receipt
+ # stream ordering of both threaded & unthreaded receipts to compare against the
+ # summary table.
+ #
+ # PostgreSQL and SQLite differ in comparing scalar numerics.
+ if isinstance(self.database_engine, PostgresEngine):
+ # GREATEST ignores NULLs.
+ max_clause = """GREATEST(
+ threaded_receipt_stream_ordering,
+ unthreaded_receipt_stream_ordering
+ )"""
+ else:
+ # MAX returns NULL if any are NULL, so COALESCE to 0 first.
+ max_clause = """MAX(
+ COALESCE(threaded_receipt_stream_ordering, 0),
+ COALESCE(unthreaded_receipt_stream_ordering, 0)
+ )"""
+
+ sql = f"""
+ {receipts_cte}
+ SELECT eps.room_id, eps.thread_id, notif_count
+ FROM event_push_summary AS eps
+ {receipts_joins}
+ WHERE user_id = ?
+ AND notif_count != 0
+ AND (
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
+ OR last_receipt_stream_ordering = {max_clause}
+ )
+ """
+ txn.execute(sql, args)
+
+ seen_thread_ids = set()
+ room_to_count: Dict[str, int] = defaultdict(int)
+
+ for room_id, thread_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+ seen_thread_ids.add(thread_id)
+
+ # Now get any event push actions that haven't been rotated using the same OR
+ # join and filter by receipt and event push summary rotated up to stream ordering.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND epa.notif = 1
+ AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id, epa.thread_id
+ """
+ txn.execute(sql, args)
+
+ for room_id, thread_id, notif_count in txn:
+ # Note: only count push actions we have valid summaries for with up to date receipt.
+ if thread_id not in seen_thread_ids:
+ continue
+ room_to_count[room_id] += notif_count
+
+ thread_id_clause, thread_ids_args = make_in_list_sql_clause(
+ self.database_engine, "epa.thread_id", seen_thread_ids
+ )
+
+ # Finally re-check event_push_actions for any rooms not in the summary, ignoring
+ # the rotated up-to position. This handles the case where a read receipt has arrived
+ # but not been rotated meaning the summary table is out of date, so we go back to
+ # the push actions table.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND NOT {thread_id_clause}
+ AND epa.notif = 1
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id
+ """
+
+ args.extend(thread_ids_args)
+ txn.execute(sql, args)
+
+ for room_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+
+ return room_to_count
+
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 1309bfd374..78906a5e1d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019, 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,8 +50,14 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import IdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ IdGenerator,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -114,6 +120,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self.config: HomeServerConfig = hs.config
+ self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+
+ if isinstance(database.engine, PostgresEngine):
+ self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="un_partial_stated_room_stream",
+ instance_name=self._instance_name,
+ tables=[
+ ("un_partial_stated_room_stream", "instance_name", "stream_id")
+ ],
+ sequence_name="un_partial_stated_room_stream_sequence",
+ # TODO(faster_joins, multiple writers) Support multiple writers.
+ writers=["master"],
+ )
+ else:
+ self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
+ db_conn, "un_partial_stated_room_stream", "stream_id"
+ )
+
async def store_room(
self,
room_id: str,
@@ -1216,70 +1242,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return room_servers
- async def clear_partial_state_room(self, room_id: str) -> bool:
- """Clears the partial state flag for a room.
-
- Args:
- room_id: The room whose partial state flag is to be cleared.
-
- Returns:
- `True` if the partial state flag has been cleared successfully.
-
- `False` if the partial state flag could not be cleared because the room
- still contains events with partial state.
- """
- try:
- await self.db_pool.runInteraction(
- "clear_partial_state_room", self._clear_partial_state_room_txn, room_id
- )
- return True
- except self.db_pool.engine.module.IntegrityError as e:
- # Assume that any `IntegrityError`s are due to partial state events.
- logger.info(
- "Exception while clearing lazy partial-state-room %s, retrying: %s",
- room_id,
- e,
- )
- return False
-
- def _clear_partial_state_room_txn(
- self, txn: LoggingTransaction, room_id: str
- ) -> None:
- DatabasePool.simple_delete_txn(
- txn,
- table="partial_state_rooms_servers",
- keyvalues={"room_id": room_id},
- )
- DatabasePool.simple_delete_one_txn(
- txn,
- table="partial_state_rooms",
- keyvalues={"room_id": room_id},
- )
- self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
- self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
- )
-
- # We now delete anything from `device_lists_remote_pending` with a
- # stream ID less than the minimum
- # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
- device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
- txn,
- table="partial_state_rooms",
- keyvalues={},
- retcol="MIN(device_lists_stream_id)",
- allow_none=True,
- )
- if device_lists_stream_id is None:
- # There are no rooms being currently partially joined, so we delete everything.
- txn.execute("DELETE FROM device_lists_remote_pending")
- else:
- sql = """
- DELETE FROM device_lists_remote_pending
- WHERE stream_id <= ?
- """
- txn.execute(sql, (device_lists_stream_id,))
-
@cached()
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@@ -1315,6 +1277,66 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
return result["join_event_id"], result["device_lists_stream_id"]
+ def get_un_partial_stated_rooms_token(self) -> int:
+ # TODO(faster_joins, multiple writers): This is inappropriate if there
+ # are multiple writers because workers that don't write often will
+ # hold all readers up.
+ # (See `MultiWriterIdGenerator.get_persisted_upto_position` for an
+ # explanation.)
+ return self._un_partial_stated_rooms_stream_id_gen.get_current_token()
+
+ async def get_un_partial_stated_rooms_from_stream(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
+ """Get updates for caches replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def get_un_partial_stated_rooms_from_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
+ sql = """
+ SELECT stream_id, room_id
+ FROM un_partial_stated_room_stream
+ WHERE ? < stream_id AND stream_id <= ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
+ updates = [(row[0], (row[1],)) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_rooms_from_stream",
+ get_un_partial_stated_rooms_from_stream_txn,
+ )
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1806,6 +1828,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+ self._instance_name = hs.get_instance_name()
+
async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
) -> None:
@@ -2270,3 +2294,84 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self.is_room_blocked,
(room_id,),
)
+
+ async def clear_partial_state_room(self, room_id: str) -> bool:
+ """Clears the partial state flag for a room.
+
+ Args:
+ room_id: The room whose partial state flag is to be cleared.
+
+ Returns:
+ `True` if the partial state flag has been cleared successfully.
+
+ `False` if the partial state flag could not be cleared because the room
+ still contains events with partial state.
+ """
+ try:
+ async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id:
+ await self.db_pool.runInteraction(
+ "clear_partial_state_room",
+ self._clear_partial_state_room_txn,
+ room_id,
+ un_partial_state_room_stream_id,
+ )
+ return True
+ except self.db_pool.engine.module.IntegrityError as e:
+ # Assume that any `IntegrityError`s are due to partial state events.
+ logger.info(
+ "Exception while clearing lazy partial-state-room %s, retrying: %s",
+ room_id,
+ e,
+ )
+ return False
+
+ def _clear_partial_state_room_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ un_partial_state_room_stream_id: int,
+ ) -> None:
+ DatabasePool.simple_delete_txn(
+ txn,
+ table="partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ )
+ DatabasePool.simple_delete_one_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ )
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_partial_state_servers_at_join, (room_id,)
+ )
+
+ DatabasePool.simple_insert_txn(
+ txn,
+ "un_partial_stated_room_stream",
+ {
+ "stream_id": un_partial_state_room_stream_id,
+ "instance_name": self._instance_name,
+ "room_id": room_id,
+ },
+ )
+
+ # We now delete anything from `device_lists_remote_pending` with a
+ # stream ID less than the minimum
+ # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
+ device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={},
+ retcol="MIN(device_lists_stream_id)",
+ allow_none=True,
+ )
+ if device_lists_stream_id is None:
+ # There are no rooms being currently partially joined, so we delete everything.
+ txn.execute("DELETE FROM device_lists_remote_pending")
+ else:
+ sql = """
+ DELETE FROM device_lists_remote_pending
+ WHERE stream_id <= ?
+ """
+ txn.execute(sql, (device_lists_stream_id,))
|