diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef475e18c7..52146aacc8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -26,6 +26,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -105,6 +106,11 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"AccountDataAndTagsChangeCache", account_max
)
+ self.db_pool.updates.register_background_update_handler(
+ "delete_account_data_for_deactivated_users",
+ self._delete_account_data_for_deactivated_users,
+ )
+
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
@@ -158,9 +164,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cached(num_args=2, max_entries=5000)
+ @cached(num_args=2, max_entries=5000, tree=True)
async def get_global_account_data_by_type_for_user(
- self, data_type: str, user_id: str
+ self, user_id: str, data_type: str
) -> Optional[JsonDict]:
"""
Returns:
@@ -179,7 +185,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
else:
return None
- @cached(num_args=2)
+ @cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
@@ -210,7 +216,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- @cached(num_args=3, max_entries=5000)
+ @cached(num_args=3, max_entries=5000, tree=True)
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
@@ -392,7 +398,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id)
+ (row.user_id, row.data_type)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
@@ -476,7 +482,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
- (account_data_type, user_id)
+ (user_id, account_data_type)
)
return self._account_data_id_gen.get_current_token()
@@ -546,6 +552,123 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ async def purge_account_data_for_user(self, user_id: str) -> None:
+ """
+ Removes ALL the account data for a user.
+ Intended to be used upon user deactivation.
+
+ Also purges the user from the ignored_users cache table
+ and the push_rules cache tables.
+ """
+
+ await self.db_pool.runInteraction(
+ "purge_account_data_for_user_txn",
+ self._purge_account_data_for_user_txn,
+ user_id,
+ )
+
+ def _purge_account_data_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
+ """
+ See `purge_account_data_for_user`.
+ """
+ # Purge from the primary account_data tables.
+ self.db_pool.simple_delete_txn(
+ txn, table="account_data", keyvalues={"user_id": user_id}
+ )
+
+ self.db_pool.simple_delete_txn(
+ txn, table="room_account_data", keyvalues={"user_id": user_id}
+ )
+
+ # Purge from ignored_users where this user is the ignorer.
+ # N.B. We don't purge where this user is the ignoree, because that
+ # interferes with other users' account data.
+ # It's also not this user's data to delete!
+ self.db_pool.simple_delete_txn(
+ txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
+ )
+
+ # Remove the push rules
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_enable", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_stream", keyvalues={"user_id": user_id}
+ )
+
+ # Invalidate caches as appropriate
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room_and_type, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_global_account_data_by_type_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room, (user_id,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_push_rules_enabled_for_user, (user_id,)
+ )
+ # This user might be contained in the ignored_by cache for other users,
+ # so we have to invalidate it all.
+ self._invalidate_all_cache_and_stream(txn, self.ignored_by)
+
+ async def _delete_account_data_for_deactivated_users(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """
+ Retroactively purges account data for users that have already been deactivated.
+ Gets run as a background update caused by a schema delta.
+ """
+
+ last_user: str = progress.get("last_user", "")
+
+ def _delete_account_data_for_deactivated_users_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ sql = """
+ SELECT name FROM users
+ WHERE deactivated = ? AND name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (1, last_user, batch_size))
+ users = [row[0] for row in txn]
+
+ for user in users:
+ self._purge_account_data_for_user_txn(txn, user_id=user)
+
+ if users:
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ "delete_account_data_for_deactivated_users",
+ {"last_user": users[-1]},
+ )
+
+ return len(users)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_delete_account_data_for_deactivated_users",
+ _delete_account_data_for_deactivated_users_txn,
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "delete_account_data_for_deactivated_users"
+ )
+
+ return number_deleted
+
class AccountDataStore(AccountDataWorkerStore):
pass
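The cache-key reordering above (user_id first) together with `tree=True` is what lets `_purge_account_data_for_user_txn` invalidate with a bare `(user_id,)` prefix. A minimal, self-contained sketch of that prefix-invalidation idea, using a toy cache rather than Synapse's real `TreeCache`:

from typing import Dict, Optional, Tuple


class PrefixInvalidatingCache:
    """Toy stand-in for a tree cache keyed by tuples, e.g. (user_id, data_type)."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[str, ...], dict] = {}

    def set(self, key: Tuple[str, ...], value: dict) -> None:
        self._entries[key] = value

    def get(self, key: Tuple[str, ...]) -> Optional[dict]:
        return self._entries.get(key)

    def invalidate(self, key_prefix: Tuple[str, ...]) -> None:
        # Drop every entry whose key starts with the given prefix.
        self._entries = {
            k: v
            for k, v in self._entries.items()
            if k[: len(key_prefix)] != key_prefix
        }


cache = PrefixInvalidatingCache()
cache.set(("@alice:test", "m.ignored_user_list"), {"ignored_users": {}})
cache.set(("@alice:test", "m.direct"), {"@bob:test": ["!room:test"]})
cache.set(("@bob:test", "m.direct"), {})

# Purging @alice:test only needs the user_id prefix, mirroring the
# (user_id,) invalidations in _purge_account_data_for_user_txn above.
cache.invalidate(("@alice:test",))
assert cache.get(("@alice:test", "m.direct")) is None
assert cache.get(("@bob:test", "m.direct")) == {}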
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 92c95a41d7..304814af5d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict],
+ to_device_messages: List[JsonDict],
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
service: The service who the transaction is for.
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
+ to_device_messages: A list of to-device messages to put in the transaction.
Returns:
A new transaction.
@@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
(service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(
- service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+ service=service,
+ id=new_txn_id,
+ events=events,
+ ephemeral=ephemeral,
+ to_device_messages=to_device_messages,
)
return await self.db_pool.runInteraction(
@@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(
- service=service, id=entry["txn_id"], events=events, ephemeral=[]
+ service=service,
+ id=entry["txn_id"],
+ events=events,
+ ephemeral=[],
+ to_device_messages=[],
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -384,14 +394,14 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = await self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids, get_prev_content=True)
return upper_bound, events
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence"):
+ if type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
- async def set_type_stream_id_for_appservice(
+ async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence"):
+ if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
- def set_type_stream_id_for_appservice_txn(txn):
+ def set_appservice_stream_type_pos_txn(txn):
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
)
await self.db_pool.runInteraction(
- "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+ "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 0024348067..c428dd5596 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate_all)
self._send_invalidation_to_replication(txn, cache_func.__name__, None)
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ def _invalidate_state_caches_and_stream(
+ self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
+ ) -> None:
"""Special case invalidation of caches based on current state.
We special case this so that we can batch the cache invalidations into a
@@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
Args:
txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4eca97189b..8801b7b2dd 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
- async def get_new_messages_for_device(
+ async def get_messages_for_user_devices(
+ self,
+ user_ids: Collection[str],
+ from_stream_id: int,
+ to_stream_id: int,
+ ) -> Dict[Tuple[str, str], List[JsonDict]]:
+ """
+ Retrieve to-device messages for a given set of users.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
+ Args:
+ user_ids: The users to retrieve to-device messages for.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+ Returns:
+ A dictionary of (user id, device id) -> list of to-device messages.
+ """
+ # We expect the stream ID returned by _get_device_messages to always
+ # be to_stream_id. So, no need to return it from this function.
+ (
+ user_id_device_id_to_messages,
+ last_processed_stream_id,
+ ) = await self._get_device_messages(
+ user_ids=user_ids,
+ from_stream_id=from_stream_id,
+ to_stream_id=to_stream_id,
+ )
+
+ assert (
+ last_processed_stream_id == to_stream_id
+ ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
+
+ return user_id_device_id_to_messages
+
+ async def get_messages_for_device(
self,
user_id: str,
- device_id: Optional[str],
- last_stream_id: int,
- current_stream_id: int,
+ device_id: str,
+ from_stream_id: int,
+ to_stream_id: int,
limit: int = 100,
- ) -> Tuple[List[dict], int]:
+ ) -> Tuple[List[JsonDict], int]:
"""
+ Retrieve to-device messages for a single user device.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
Args:
- user_id: The recipient user_id.
- device_id: The recipient device_id.
- last_stream_id: The last stream ID checked.
- current_stream_id: The current position of the to device
- message stream.
- limit: The maximum number of messages to retrieve.
+ user_id: The ID of the user to retrieve messages for.
+ device_id: The ID of the device to retrieve to-device messages for.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+ limit: A limit on the number of to-device messages returned.
Returns:
A tuple containing:
- * A list of messages for the device.
- * The max stream token of these messages. There may be more to retrieve
- if the given limit was reached.
+ * A list of to-device messages within the given stream id range intended for
+ the given user / device combo.
+ * The last-processed stream ID. Subsequent calls of this function with the
+ same device should pass this value as 'from_stream_id'.
"""
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_stream_id
+ (
+ user_id_device_id_to_messages,
+ last_processed_stream_id,
+ ) = await self._get_device_messages(
+ user_ids=[user_id],
+ device_id=device_id,
+ from_stream_id=from_stream_id,
+ to_stream_id=to_stream_id,
+ limit=limit,
)
- if not has_changed:
- return [], current_stream_id
- def get_new_messages_for_device_txn(txn):
- sql = (
- "SELECT stream_id, message_json FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
+ if not user_id_device_id_to_messages:
+ # There were no messages!
+ return [], to_stream_id
+
+ # Extract the messages, no need to return the user and device ID again
+ to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+
+ return to_device_messages, last_processed_stream_id
+
+ async def _get_device_messages(
+ self,
+ user_ids: Collection[str],
+ from_stream_id: int,
+ to_stream_id: int,
+ device_id: Optional[str] = None,
+ limit: Optional[int] = None,
+ ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+ """
+ Retrieve pending to-device messages for a collection of user devices.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
+ Note that a stream ID can be shared by multiple copies of the same message with
+ different recipient devices. Stream IDs are only unique in the context of a single
+ user ID / device ID pair. Thus, applying a limit (of messages to return) when working
+ with a sliding window of stream IDs is only possible when querying messages of a
+ single user device.
+
+ Finally, note that device IDs are not unique across users.
+
+ Args:
+ user_ids: The user IDs to filter device messages by.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+ device_id: A device ID to query to-device messages for. If not provided, to-device
+ messages from all device IDs for the given user IDs will be queried. May not be
+ provided if `user_ids` contains more than one entry.
+ limit: The maximum number of to-device messages to return. Can only be used when
+ passing a single user ID / device ID tuple.
+
+ Returns:
+ A tuple containing:
+ * A dict of (user_id, device_id) -> list of to-device messages
+ * The last-processed stream ID. If this is less than `to_stream_id`, then
+ there may be more messages to retrieve. If `limit` is not set, then this
+ is always equal to 'to_stream_id'.
+ """
+ if not user_ids:
+ logger.warning("No users provided upon querying for device IDs")
+ return {}, to_stream_id
+
+ # Prevent a query for one user's device also retrieving another user's device with
+ # the same device ID (device IDs are not unique across users).
+ if len(user_ids) > 1 and device_id is not None:
+ raise AssertionError(
+ "Programming error: 'device_id' cannot be supplied to "
+ "_get_device_messages when >1 user_id has been provided"
)
- txn.execute(
- sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+
+ # A limit can only be applied when querying for a single user ID / device ID tuple.
+ # See the docstring of this function for more details.
+ if limit is not None and device_id is None:
+ raise AssertionError(
+ "Programming error: _get_device_messages was passed 'limit' "
+ "without a specific user_id/device_id"
)
- messages = []
- stream_pos = current_stream_id
+ user_ids_to_query: Set[str] = set()
+ device_ids_to_query: Set[str] = set()
+
+ # Note that a device ID could be an empty str
+ if device_id is not None:
+ # If a device ID was passed, use it to filter results.
+ # Otherwise, device IDs will be derived from the given collection of user IDs.
+ device_ids_to_query.add(device_id)
+
+ # Determine which users have devices with pending messages
+ for user_id in user_ids:
+ if self._device_inbox_stream_cache.has_entity_changed(
+ user_id, from_stream_id
+ ):
+ # This user has new messages sent to them. Query messages for them
+ user_ids_to_query.add(user_id)
+
+ def get_device_messages_txn(txn: LoggingTransaction):
+ # Build a query to select messages from any of the given devices that
+ # are between the given stream id bounds.
+
+ # If a list of device IDs was not provided, retrieve all device IDs
+ # for the given users. We explicitly do not query hidden devices, as
+ # hidden devices should not receive to-device messages.
+ # Note that this is more efficient than just dropping `device_id` from the query,
+ # since device_inbox has an index on `(user_id, device_id, stream_id)`
+ if not device_ids_to_query:
+ user_device_dicts = self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ )
- for row in txn:
- stream_pos = row[0]
- messages.append(db_to_json(row[1]))
+ device_ids_to_query.update(
+ {row["device_id"] for row in user_device_dicts}
+ )
- # If the limit was not reached we know that there's no more data for this
- # user/device pair up to current_stream_id.
- if len(messages) < limit:
- stream_pos = current_stream_id
+ if not device_ids_to_query:
+ # We've ended up with no devices to query.
+ return {}, to_stream_id
- return messages, stream_pos
+ # We include both user IDs and device IDs in this query, as we have an index
+ # (device_inbox_user_stream_id) for them.
+ user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
+ self.database_engine, "user_id", user_ids_to_query
+ )
+ (
+ device_id_many_clause_sql,
+ device_id_many_clause_args,
+ ) = make_in_list_sql_clause(
+ self.database_engine, "device_id", device_ids_to_query
+ )
+
+ sql = f"""
+ SELECT stream_id, user_id, device_id, message_json FROM device_inbox
+ WHERE {user_id_many_clause_sql}
+ AND {device_id_many_clause_sql}
+ AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ sql_args = (
+ *user_id_many_clause_args,
+ *device_id_many_clause_args,
+ from_stream_id,
+ to_stream_id,
+ )
+
+ # If a limit was provided, limit the data retrieved from the database
+ if limit is not None:
+ sql += "LIMIT ?"
+ sql_args += (limit,)
+
+ txn.execute(sql, sql_args)
+
+ # Create and fill a dictionary of (user ID, device ID) -> list of messages
+ # intended for each device.
+ last_processed_stream_pos = to_stream_id
+ recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
+ for row in txn:
+ last_processed_stream_pos = row[0]
+ recipient_user_id = row[1]
+ recipient_device_id = row[2]
+ message_dict = db_to_json(row[3])
+
+ # Store the device details
+ recipient_device_to_messages.setdefault(
+ (recipient_user_id, recipient_device_id), []
+ ).append(message_dict)
+
+ if limit is not None and txn.rowcount == limit:
+ # We ended up bumping up against the message limit. There may be more messages
+ # to retrieve. Return what we have, as well as the last stream position that
+ # was processed.
+ #
+ # The caller is expected to set this as the lower (exclusive) bound
+ # for the next query of this device.
+ return recipient_device_to_messages, last_processed_stream_pos
+
+ # The limit was not reached, thus we know that recipient_device_to_messages
+ # contains all to-device messages for the given device and stream id range.
+ #
+ # We return to_stream_id, which the caller should then provide as the lower
+ # (exclusive) bound on the next query of this device.
+ return recipient_device_to_messages, to_stream_id
return await self.db_pool.runInteraction(
- "get_new_messages_for_device", get_new_messages_for_device_txn
+ "get_device_messages", get_device_messages_txn
)
@trace
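A runnable sketch of the pagination contract documented on `get_messages_for_device` above: stream id bounds are (from_stream_id, to_stream_id], and the returned last-processed stream ID is fed back as the next `from_stream_id` until it equals `to_stream_id`. The in-memory inbox is only a stand-in for the `device_inbox` table:

from typing import Dict, List, Tuple

inbox: List[Tuple[int, Dict]] = [
    (1, {"type": "m.room_key_request"}),
    (2, {"type": "m.dummy"}),
    (3, {"type": "m.dummy"}),
    (5, {"type": "m.room_key_request"}),
]


def get_messages(from_id: int, to_id: int, limit: int) -> Tuple[List[Dict], int]:
    rows = [(sid, msg) for sid, msg in inbox if from_id < sid <= to_id][:limit]
    if len(rows) == limit:
        # The limit was hit: there may be more rows, so hand back the last
        # stream id processed for the caller to resume from.
        return [msg for _, msg in rows], rows[-1][0]
    # Otherwise everything up to to_id has been processed.
    return [msg for _, msg in rows], to_id


collected: List[Dict] = []
from_id, to_id = 0, 5
while True:
    batch, from_id = get_messages(from_id, to_id, limit=2)
    collected.extend(batch)
    if from_id == to_id:
        break

assert len(collected) == 4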
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a556f17dac..ca71f073fc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -65,7 +65,7 @@ class _NoChainCoverIndex(Exception):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
-class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1ae1ebe108..b7554154ac 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1389,6 +1389,8 @@ class PersistEventsStore:
"received_ts",
"sender",
"contains_url",
+ "state_key",
+ "rejection_reason",
),
values=(
(
@@ -1405,8 +1407,10 @@ class PersistEventsStore:
self._clock.time_msec(),
event.sender,
"url" in event.content and isinstance(event.content["url"], str),
+ event.get_state_key(),
+ context.rejected or None,
)
- for event, _ in events_and_contexts
+ for event, context in events_and_contexts
),
)
@@ -1456,6 +1460,7 @@ class PersistEventsStore:
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
+ # (events.rejection_reason has already been set above)
self._store_rejections_txn(txn, event.event_id, context.rejected)
to_remove.add(event)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..e87a8fb85d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -390,7 +390,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_search",
"events",
"group_rooms",
- "public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e01c94930a..92539f5d41 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _load_rules(rawrules, enabled_map, use_new_defaults=False):
+def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -52,7 +52,7 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist, use_new_defaults))
+ rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule["rule_id"]
@@ -112,10 +112,6 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- self._users_new_default_push_rules = (
- hs.config.server.users_new_default_push_rules
- )
-
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
@@ -145,9 +141,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- use_new_defaults = user_id in self._users_new_default_push_rules
-
- return _load_rules(rows, enabled_map, use_new_defaults)
+ return _load_rules(rows, enabled_map)
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
@@ -206,13 +200,7 @@ class PushRulesWorkerStore(
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
- use_new_defaults = user_id in self._users_new_default_push_rules
-
- results[user_id] = _load_rules(
- rules,
- enabled_map_by_user.get(user_id, {}),
- use_new_defaults,
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..37468a5183 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterable,
- List,
- Optional,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
import attr
from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ latest_event: EventBase
+ count: int
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -60,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
@cached(tree=True)
@@ -585,7 +599,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +630,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
- aggregations: Dict[str, Any] = {}
+ aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+ aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
+ aggregations.references = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)
if edit:
- aggregations[RelationTypes.REPLACE] = edit
+ aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
@@ -644,11 +658,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id
)
if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- "latest_event": latest_thread_event,
- "count": thread_count,
- "current_user_participated": participated,
- }
+ aggregations.thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ current_user_participated=participated,
+ )
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@@ -657,7 +671,7 @@ class RelationsWorkerStore(SQLBaseStore):
self,
events: Iterable[EventBase],
user_id: str,
- ) -> Dict[str, Dict[str, Any]]:
+ ) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
@@ -668,15 +682,12 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
- # If bundled aggregations are disabled, nothing to do.
- if not self._msc1849_enabled:
- return {}
# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result is not None:
+ if event_result:
results[event.event_id] = event_result
return results
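The per-event filter above changed from `if event_result is not None` to a plain truthiness check because `BundledAggregations` defines `__bool__`, so instances with no populated fields drop out of the results. A stripped-down sketch of the same pattern:

from typing import Optional

import attr


@attr.s(slots=True, auto_attribs=True)
class Aggregations:
    annotations: Optional[dict] = None
    replace: Optional[dict] = None

    def __bool__(self) -> bool:
        return bool(self.annotations or self.replace)


empty = Aggregations()
edited = Aggregations(replace={"event_id": "$edit"})

results = {
    event_id: aggs
    for event_id, aggs in {"$a": empty, "$b": edited}.items()
    if aggs  # falsy (empty) aggregations are skipped, as in the loop above
}
assert list(results) == ["$b"]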
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 3201623fe4..0518b8b910 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable, List, Tuple
+from typing import Collection, Dict, List, Tuple
from unpaddedbase64 import encode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.types import Cursor
+from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage.databases.main.events_worker import (
+ EventRedactBehaviour,
+ EventsWorkerStore,
+)
from synapse.util.caches.descriptors import cached, cachedList
-class SignatureWorkerStore(SQLBaseStore):
+class SignatureWorkerStore(EventsWorkerStore):
@cached()
def get_event_reference_hash(self, event_id):
# This is a dummy function to allow get_event_reference_hashes
@@ -32,7 +35,7 @@ class SignatureWorkerStore(SQLBaseStore):
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
async def get_event_reference_hashes(
- self, event_ids: Iterable[str]
+ self, event_ids: Collection[str]
) -> Dict[str, Dict[str, bytes]]:
"""Get all hashes for given events.
@@ -41,18 +44,27 @@ class SignatureWorkerStore(SQLBaseStore):
Returns:
A mapping of event ID to a mapping of algorithm to hash.
+ Returns an empty dict for a given event id if that event is unknown.
"""
+ events = await self.get_events(
+ event_ids,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ allow_rejected=True,
+ )
- def f(txn):
- return {
- event_id: self._get_event_reference_hashes_txn(txn, event_id)
- for event_id in event_ids
- }
+ hashes: Dict[str, Dict[str, bytes]] = {}
+ for event_id in event_ids:
+ event = events.get(event_id)
+ if event is None:
+ hashes[event_id] = {}
+ else:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ hashes[event_id] = {ref_alg: ref_hash_bytes}
- return await self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return hashes
async def add_event_hashes(
- self, event_ids: Iterable[str]
+ self, event_ids: Collection[str]
) -> List[Tuple[str, Dict[str, str]]]:
"""
@@ -70,24 +82,6 @@ class SignatureWorkerStore(SQLBaseStore):
return list(encoded_hashes.items())
- def _get_event_reference_hashes_txn(
- self, txn: Cursor, event_id: str
- ) -> Dict[str, bytes]:
- """Get all the hashes for a given PDU.
- Args:
- txn:
- event_id: Id for the Event.
- Returns:
- A mapping of algorithm -> hash.
- """
- query = (
- "SELECT algorithm, hash"
- " FROM event_reference_hashes"
- " WHERE event_id = ?"
- )
- txn.execute(query, (event_id,))
- return {k: v for k, v in txn}
-
class SignatureStore(SignatureWorkerStore):
"""Persistence for event signatures and hashes"""
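`get_event_reference_hashes` now recomputes hashes from the events themselves via `compute_event_reference_hash` rather than querying the old event_reference_hashes helper. A rough, simplified sketch of the underlying idea (strip signature-related fields, canonicalise, hash); `json.dumps(sort_keys=True)` stands in for true canonical JSON and the field handling is abbreviated:

import hashlib
import json
from typing import Dict, Tuple


def reference_hash(event_dict: Dict) -> Tuple[str, bytes]:
    # Drop fields that are not covered by the reference hash (simplified).
    stripped = {
        k: v
        for k, v in event_dict.items()
        if k not in ("signatures", "unsigned", "age_ts")
    }
    canonical = json.dumps(stripped, sort_keys=True, separators=(",", ":"))
    hashed = hashlib.sha256(canonical.encode("utf-8"))
    return hashed.name, hashed.digest()


alg, digest = reference_hash(
    {"type": "m.room.message", "sender": "@alice:test", "signatures": {}}
)
assert alg == "sha256" and len(digest) == 32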
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
stream_ordering: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+ events_before: List[EventBase]
+ events_after: List[EventBase]
+ start: RoomStreamToken
+ end: RoomStreamToken
+
+
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
- ) -> dict:
+ ) -> _EventsAround:
"""Retrieve events and pagination tokens around a given event in a
room.
"""
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
list(results["after"]["event_ids"]), get_prev_content=True
)
- return {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
+ return _EventsAround(
+ events_before=events_before,
+ events_after=events_after,
+ start=results["before"]["token"],
+ end=results["after"]["token"],
+ )
def _get_events_around_txn(
self,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 4b78b4d098..ba79e19f7f 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -561,6 +561,54 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
"get_destinations_paginate_txn", get_destinations_paginate_txn
)
+ async def get_destination_rooms_paginate(
+ self, destination: str, start: int, limit: int, direction: str = "f"
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of a destination's rooms.
+ This will return a json list of rooms and the
+ total number of rooms.
+
+ Args:
+ destination: the destination to query
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ direction: sort ascending or descending by room_id
+ Returns:
+ A tuple of a list of rooms and a count of total rooms.
+ """
+
+ def get_destination_rooms_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT COUNT(*) as total_rooms
+ FROM destination_rooms
+ WHERE destination = ?
+ """
+ txn.execute(sql, [destination])
+ count = cast(Tuple[int], txn.fetchone())[0]
+
+ rooms = self.db_pool.simple_select_list_paginate_txn(
+ txn=txn,
+ table="destination_rooms",
+ orderby="room_id",
+ start=start,
+ limit=limit,
+ retcols=("room_id", "stream_ordering"),
+ order_direction=order,
+ )
+ return rooms, count
+
+ return await self.db_pool.runInteraction(
+ "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn
+ )
+
async def is_destination_known(self, destination: str) -> bool:
"""Check if a destination is known to the server."""
result = await self.db_pool.simple_select_one_onecol(
|