diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
- FilteringStore,
+ FilteringWorkerStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..a9843f6e17 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
cast,
@@ -39,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -63,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
):
super().__init__(database, db_conn, hs)
- # `_can_write_to_account_data` indicates whether the current worker is allowed
- # to write account data. A value of `True` implies that `_account_data_id_gen`
- # is an `AbstractStreamIdGenerator` and not just a tracker.
- self._account_data_id_gen: AbstractStreamIdTracker
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
+ self._account_data_id_gen: AbstractStreamIdGenerator
+
if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -122,25 +120,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
@cached()
- async def get_account_data_for_user(
+ async def get_global_account_data_for_user(
self, user_id: str
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Mapping[str, JsonDict]:
"""
- Get all the client account_data for a user.
+ Get all the global client account_data for a user.
If experimental MSC3391 support is enabled, any entries with an empty
content body are excluded; as this means they have been deleted.
Args:
user_id: The user to get the account_data for.
+
Returns:
- A 2-tuple of a dict of global account_data and a dict mapping from
- room_id string to per room account_data dicts.
+ The global account_data.
"""
- def get_account_data_for_user_txn(
+ def get_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Dict[str, JsonDict]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -158,10 +156,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
- global_account_data = {
+ return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
+ return await self.db_pool.runInteraction(
+ "get_global_account_data_for_user", get_global_account_data_for_user
+ )
+
+ @cached()
+ async def get_room_account_data_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
+ """
+ Get all of the per-room client account_data for a user.
+
+ If experimental MSC3391 support is enabled, any entries with an empty
+ content body are excluded, as this means they have been deleted.
+
+ Args:
+ user_id: The user to get the account_data for.
+
+ Returns:
+ A dict mapping from room_id string to per-room account_data dicts.
+ """
+
+ def get_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -185,10 +207,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_data[row["account_data_type"]] = db_to_json(row["content"])
- return global_account_data, by_room
+ return by_room
return await self.db_pool.runInteraction(
- "get_account_data_for_user", get_account_data_for_user_txn
+ "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
)
@cached(num_args=2, max_entries=5000, tree=True)
@@ -212,10 +234,41 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
return None
+ async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+ self, user_id: str, data_type: str
+ ) -> Optional[int]:
+ """
+ Returns:
+ The stream ID of the account data,
+ or None if there is no such account data.
+ """
+
+ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
+ sql = """
+ SELECT stream_id FROM account_data
+ WHERE user_id = ? AND account_data_type = ?
+ ORDER BY stream_id DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (user_id, data_type))
+
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ else:
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_latest_stream_id_for_global_account_data_by_type_for_user",
+ get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+ )
+
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
- ) -> Dict[str, JsonDict]:
+ ) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
@@ -342,36 +395,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- async def get_updated_account_data_for_user(
+ async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- """Get all the client account_data for a that's changed for a user
+ ) -> Dict[str, JsonDict]:
+ """Get all the global account_data that's changed for a user.
Args:
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
+
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A dict of global account_data.
"""
- def get_updated_account_data_for_user_txn(
+ def get_updated_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- sql = (
- "SELECT account_data_type, content FROM account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
-
+ ) -> Dict[str, JsonDict]:
+ sql = """
+ SELECT account_data_type, content FROM account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+ return {row[0]: db_to_json(row[1]) for row in txn}
- sql = (
- "SELECT room_id, account_data_type, content FROM room_account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return {}
+
+ return await self.db_pool.runInteraction(
+ "get_updated_global_account_data_for_user",
+ get_updated_global_account_data_for_user,
+ )
+
+ async def get_updated_room_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Get all the room account_data that's changed for a user.
+ Args:
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
+
+ Returns:
+ A dict mapping from room_id string to per-room account_data dicts.
+ """
+
+ def get_updated_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ sql = """
+ SELECT room_id, account_data_type, content FROM room_account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +457,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
- return global_account_data, account_data_by_room
+ return account_data_by_room
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
- return {}, {}
+ return {}
return await self.db_pool.runInteraction(
- "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+ "get_updated_room_account_data_for_user",
+ get_updated_room_account_data_for_user_txn,
)
@cached(max_entries=5000, iterable=True)
@@ -444,7 +523,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self.get_global_account_data_by_type_for_user.invalidate(
(row.user_id, row.data_type)
)
- self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_global_account_data_for_user.invalidate((row.user_id,))
+ self.get_room_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
@@ -475,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -492,7 +571,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), content
@@ -502,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str
- ) -> Optional[int]:
+ ) -> int:
"""Delete the room account data for the user of a given type.
Args:
@@ -515,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
data to delete.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_room_txn(
txn: LoggingTransaction, next_id: int
@@ -554,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
next_id,
)
- if not row_updated:
- return None
-
- self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
- self.get_account_data_for_room.invalidate((user_id, room_id))
- self.get_account_data_for_room_and_type.prefill(
- (user_id, room_id, account_data_type), {}
- )
+ if row_updated:
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_room_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type), {}
+ )
return self._account_data_id_gen.get_current_token()
@@ -580,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
@@ -593,7 +668,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(user_id, account_data_type)
)
@@ -670,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self,
user_id: str,
account_data_type: str,
- ) -> Optional[int]:
+ ) -> int:
"""
Delete a single piece of user account data by type.
@@ -687,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
to delete.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_user_txn(
txn: LoggingTransaction, next_id: int
@@ -757,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
next_id,
)
- if not row_updated:
- return None
-
- self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
- self.get_global_account_data_by_type_for_user.prefill(
- (user_id, account_data_type), {}
- )
+ if row_updated:
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_global_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.prefill(
+ (user_id, account_data_type), {}
+ )
return self._account_data_id_gen.get_current_token()
@@ -822,7 +894,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room_and_type, (user_id,)
)
self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_user, (user_id,)
+ txn, self.get_global_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_room_account_data_for_user, (user_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_global_account_data_by_type_for_user, (user_id,)
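The removal of the combined `get_account_data_for_user` means callers now fetch the global and per-room mappings separately, each with its own cache entry. A minimal, hypothetical sketch of the updated calling convention, assuming a `store` object exposing the two new methods from this file:

    async def load_account_data(store, user_id: str):
        # Before this change a single call returned both mappings:
        #   global_data, room_data = await store.get_account_data_for_user(user_id)
        # Afterwards the two lookups are separate and independently cached.
        global_data = await store.get_global_account_data_for_user(user_id)
        room_data = await store.get_room_account_data_for_user(user_id)
        return global_data, room_data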
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
- ) -> List[str]:
+ ) -> Sequence[str]:
"""
Get all users in a room that the appservice controls.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 5b66431691..096dec7f87 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
- self._attempt_to_invalidate_cache(
- "get_aggregation_groups_for_event", (relates_to,)
- )
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- for (user_id, messages_by_device) in edu["messages"].items():
- for (device_id, msg) in messages_by_device.items():
+ for user_id, messages_by_device in edu["messages"].items():
+ for device_id, msg in messages_by_device.items():
with start_active_span("store_outgoing_to_device_message"):
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
) -> Tuple[int, bool]:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..5503621ad6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -51,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
@@ -90,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ self._device_list_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_lists_stream",
@@ -100,6 +100,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
+ ("device_lists_changes_converted_stream_position", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
@@ -201,7 +202,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
- async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ async def count_devices_by_users(
+ self, user_ids: Optional[Collection[str]] = None
+ ) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -212,7 +215,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
def count_devices_by_users_txn(
- txn: LoggingTransaction, user_ids: List[str]
+ txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
@@ -508,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
results.append(("org.matrix.signing_key_update", result))
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- for (user_id, edu) in results:
+ for user_id, edu in results:
issue_8631_logger.debug(
"device update to %s for %s from %s to %s: %s",
destination,
@@ -708,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The new stream ID.
"""
- # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
- # than DeviceWorkerStore?
- async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -745,42 +746,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@trace
@cancellable
async def get_user_devices_from_cache(
- self, query_list: List[Tuple[str, Optional[str]]]
- ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+ self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list: List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
+ user_ids: users which should have all device IDs returned
+ user_and_device_ids: List of (user_id, device_id) tuples
Returns:
A tuple of (user_ids_not_in_cache, results_map), where
user_ids_not_in_cache is a set of user_ids and results_map is a
mapping of user_id -> device_id -> device_info.
"""
- user_ids = {user_id for user_id, _ in query_list}
- user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+ user_map = await self.get_device_list_last_stream_id_for_remotes(
+ list(unique_user_ids)
+ )
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
- user_ids
+ unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
- user_ids_not_in_cache = user_ids - user_ids_in_cache
-
- results: Dict[str, Dict[str, JsonDict]] = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
+ user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
- if device_id:
- device = await self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
+ # First fetch all the users for which all devices are to be returned.
+ results: Dict[str, Mapping[str, JsonDict]] = {}
+ for user_id in user_ids:
+ if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
+ # Then fetch all device-specific requests, but skip users we've already
+ # fetched all devices for.
+ device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+ for user_id, device_id in user_and_device_ids:
+ if user_id in user_ids_in_cache and user_id not in user_ids:
+ device = await self._get_cached_user_device(user_id, device_id)
+ device_specific_results.setdefault(user_id, {})[device_id] = device
+ results.update(device_specific_results)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +804,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
- async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
@@ -1307,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
"""
count = 0
- for (destination, user_id, stream_id, device_id) in rows:
+ for destination, user_id, stream_id, device_id in rows:
txn.execute(
delete_sql, (destination, user_id, stream_id, stream_id, device_id)
)
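`get_user_devices_from_cache` above now takes the users needing all of their devices and the specific (user_id, device_id) pairs as two separate arguments instead of one mixed query list. A hypothetical helper sketching how an old-style query list could be split into the new arguments (the helper name is illustrative only):

    from typing import List, Optional, Set, Tuple

    def split_device_query(
        query_list: List[Tuple[str, Optional[str]]]
    ) -> Tuple[Set[str], List[Tuple[str, str]]]:
        """Split an old-style (user_id, device_id-or-None) list into the two
        arguments now taken by get_user_devices_from_cache."""
        user_ids: Set[str] = set()
        user_and_device_ids: List[Tuple[str, str]] = []
        for user_id, device_id in query_list:
            if device_id is None:
                user_ids.add(user_id)  # return every cached device for this user
            else:
                user_and_device_ids.append((user_id, device_id))
        return user_ids, user_and_device_ids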
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
import attr
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
)
@cached(max_entries=5000)
- async def get_aliases_for_room(self, room_id: str) -> List[str]:
+ async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No backup with that version exists")
values = []
- for (room_id, session_id, room_key) in room_keys:
+ for room_id, session_id, room_key in room_keys:
values.append(
(
user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..a3b6c8ae8e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
+ Sequence,
Tuple,
Union,
cast,
@@ -242,9 +244,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
- result = await self.db_pool.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
+ result = await self._get_e2e_device_keys(
query_list,
include_all_devices,
include_deleted_devices,
@@ -260,13 +260,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
- "get_e2e_cross_signing_signatures",
+ "get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
# add each cross-signing signature to the correct device in the result dict.
- for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ for user_id, key_id, device_id, signature in cross_sigs_result:
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
@@ -283,9 +283,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
log_kv(result)
return result
- def _get_e2e_device_keys_txn(
+ async def _get_e2e_device_keys(
self,
- txn: LoggingTransaction,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
@@ -309,7 +308,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
# devices.
user_list = []
user_device_list = []
- for (user_id, device_id) in query_list:
+ for user_id, device_id in query_list:
if device_id is None:
user_list.append(user_id)
else:
@@ -317,7 +316,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if user_list:
user_id_in_list_clause, user_args = make_in_list_sql_clause(
- txn.database_engine, "user_id", user_list
+ self.database_engine, "user_id", user_list
)
query_clauses.append(user_id_in_list_clause)
query_params_list.append(user_args)
@@ -330,13 +329,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_device_id_in_list_clause,
user_device_args,
) = make_tuple_in_list_sql_clause(
- txn.database_engine, ("user_id", "device_id"), user_device_batch
+ self.database_engine, ("user_id", "device_id"), user_device_batch
)
query_clauses.append(user_device_id_in_list_clause)
query_params_list.append(user_device_args)
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
- for query_clause, query_params in zip(query_clauses, query_params_list):
+
+ def get_e2e_device_keys_txn(
+ txn: LoggingTransaction, query_clause: str, query_params: list
+ ) -> None:
sql = (
"SELECT user_id, device_id, "
" d.display_name, "
@@ -351,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
txn.execute(sql, query_params)
- for (user_id, device_id, display_name, key_json) in txn:
+ for user_id, device_id, display_name, key_json in txn:
assert device_id is not None
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
@@ -359,6 +361,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
display_name, db_to_json(key_json) if key_json else None
)
+ for query_clause, query_params in zip(query_clauses, query_params_list):
+ await self.db_pool.runInteraction(
+ "_get_e2e_device_keys",
+ get_e2e_device_keys_txn,
+ query_clause,
+ query_params,
+ )
+
if include_deleted_devices:
for user_id, device_id in deleted_devices:
if device_id is None:
@@ -380,7 +390,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
signature_query_clauses = []
signature_query_params = []
- for (user_id, device_id) in device_query:
+ for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
@@ -691,7 +701,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
- ) -> List[str]:
+ ) -> Sequence[str]:
"""Returns the fallback key types that have an unused key.
Args:
@@ -731,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -744,7 +754,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+ ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -765,7 +775,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# The `Optional` comes from the `@cachedList` decorator.
- return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+ return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
@@ -924,7 +934,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+ ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -940,11 +950,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = await self.db_pool.runInteraction(
- "get_e2e_cross_signing_signatures",
- self._get_e2e_cross_signing_signatures_txn,
- result,
- from_user_id,
+ result = cast(
+ Dict[str, Optional[Mapping[str, JsonDict]]],
+ await self.db_pool.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ ),
)
return result
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+ async def get_max_depth_of(
+ self, event_ids: Collection[str]
+ ) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
- async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+ async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> List[str]:
+ ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
- ) -> List[str]:
+ ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -1609,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
latest_events: List[str],
limit: int,
) -> List[str]:
-
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
event_results: List[str] = []
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a0c370fde..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -203,11 +203,18 @@ class RoomNotifCounts:
# Map of thread ID to the notification counts.
threads: Dict[str, NotifCounts]
+ @staticmethod
+ def empty() -> "RoomNotifCounts":
+ return _EMPTY_ROOM_NOTIF_COUNTS
+
def __len__(self) -> int:
# To properly account for the amount of space in any caches.
return len(self.threads) + 1
+_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
+
+
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..a8a4ed4436 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
import itertools
import logging
from collections import OrderedDict
-from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -26,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
- Sequence,
Set,
Tuple,
)
@@ -36,7 +34,7 @@ from prometheus_client import Counter
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
@@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
from synapse.util.stringutils import non_null_str_or_none
@@ -72,24 +70,6 @@ event_counter = Counter(
)
-class PartialStateConflictError(SynapseError):
- """An internal error raised when attempting to persist an event with partial state
- after the room containing the event has been un-partial stated.
-
- This error should be handled by recomputing the event context and trying again.
-
- This error has an HTTP status code so that it can be transported over replication.
- It should not be exposed to clients.
- """
-
- def __init__(self) -> None:
- super().__init__(
- HTTPStatus.CONFLICT,
- msg="Cannot persist partial state event in un-partial stated room",
- errcode=Codes.UNKNOWN,
- )
-
-
@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -306,7 +286,7 @@ class PersistEventsStore:
# The set of event_ids to return. This includes all soft-failed events
# and their prev events.
- existing_prevs = set()
+ existing_prevs: Set[str] = set()
def _get_prevs_before_rejected_txn(
txn: LoggingTransaction, batch: Collection[str]
@@ -489,7 +469,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events: List[EventBase],
) -> None:
-
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
return
@@ -571,7 +550,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, Sequence[str]],
+ event_to_auth_chain: Dict[str, StrCollection],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -865,7 +844,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, Sequence[str]],
+ event_to_auth_chain: Dict[str, StrCollection],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@@ -2045,10 +2024,6 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_relations_for_event, (redacted_relates_to,)
)
- if rel_type == RelationTypes.ANNOTATION:
- self.store._invalidate_cache_and_stream(
- txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
- )
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_references_for_event, (redacted_relates_to,)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..daef3685b0 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
import attr
@@ -29,7 +29,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows = 0
last_row_event_id = ""
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
- cast(Dict[str, Sequence[str]], event_to_auth_chain),
+ cast(Dict[str, StrCollection], event_to_auth_chain),
)
return _CalculateChainCover(
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
except Exception as e:
@@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
- txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
- )
- self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a9259fe446..0d90407ebd 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._stream_id_gen: AbstractStreamIdTracker
- self._backfill_id_gen: AbstractStreamIdTracker
+ self._stream_id_gen: AbstractStreamIdGenerator
+ self._backfill_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(redactions_sql + clause, args)
- for (redacter, redacted) in txn:
+ for redacter, redacted in txn:
d = event_dict.get(redacted)
if d:
d.redactions.append(redacter)
@@ -1779,7 +1778,7 @@ class EventsWorkerStore(SQLBaseStore):
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
- "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type,"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
" e.outlier"
" FROM events AS e"
@@ -1791,10 +1790,10 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
- " WHERE ? < event_stream_ordering"
- " AND event_stream_ordering <= ?"
+ " WHERE ? < out.event_stream_ordering"
+ " AND out.event_stream_ordering <= ?"
" AND out.instance_name = ?"
- " ORDER BY event_stream_ordering ASC"
+ " ORDER BY out.event_stream_ordering ASC"
)
txn.execute(sql, (last_id, current_id, instance_name))
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
return db_to_json(def_json)
-
-class FilteringStore(FilteringWorkerStore):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
return filter_id
- return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ attempts = 0
+ while True:
+ # Try a few times.
+ # This is technically needed if a user tries to create two filters at once,
+ # leading to two concurrent transactions.
+ # The failure case would be:
+ # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+ # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+ # - INSERT INTO ... → both transactions insert filter_id = 6
+ # One of the transactions will commit. The other will get a unique key
+ # constraint violation error (IntegrityError). This is not the same as a
+ # serialisability violation, which would be automatically retried by
+ # `runInteraction`.
+ try:
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+
+ if attempts >= 5:
+ raise StoreError(500, "Couldn't generate a filter ID.")
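The retry loop added to `add_user_filter` covers the race where two concurrent transactions both see the same MAX(filter_id) and try to insert the same ID, with the loser hitting a unique-key IntegrityError rather than a retryable serialisation failure. A standalone sketch of the same pattern, assuming a sqlite3 `user_filters` table with a unique (user_id, filter_id) constraint; the table, helper name, and bound are illustrative only:

    import sqlite3

    MAX_ATTEMPTS = 5  # mirrors the limit of 5 attempts in the diff above

    def insert_filter_with_retry(
        conn: sqlite3.Connection, user: str, filter_json: str
    ) -> int:
        """Allocate MAX(filter_id) + 1 and insert, retrying if another writer wins."""
        for _ in range(MAX_ATTEMPTS):
            try:
                with conn:  # one transaction per attempt; commits or rolls back
                    row = conn.execute(
                        "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
                        (user,),
                    ).fetchone()
                    next_id = (row[0] or 0) + 1
                    conn.execute(
                        "INSERT INTO user_filters (user_id, filter_id, filter_json)"
                        " VALUES (?, ?, ?)",
                        (user, next_id, filter_json),
                    )
                    return next_id
            except sqlite3.IntegrityError:
                continue  # lost the race on the unique (user_id, filter_id) key
        raise RuntimeError("Couldn't generate a filter ID.")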
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
-
# Set ordering
order_by_column = MediaSortOrder(order_by).value
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
- async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+ async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 9213ce0b5a..7a7c0d9c75 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -325,6 +325,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# We then run the same purge a second time without this isolation level to
# purge any of those rows which were added during the first.
+ logger.info("[purge] Starting initial main purge of [1/2]")
state_groups_to_delete = await self.db_pool.runInteraction(
"purge_room",
self._purge_room_txn,
@@ -332,6 +333,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
isolation_level=IsolationLevel.READ_COMMITTED,
)
+ logger.info("[purge] Starting secondary main purge of [2/2]")
state_groups_to_delete.extend(
await self.db_pool.runInteraction(
"purge_room",
@@ -339,6 +341,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
room_id=room_id,
),
)
+ logger.info("[purge] Done with main purge")
return state_groups_to_delete
@@ -376,7 +379,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
referenced_chain_id_tuples = list(txn)
- logger.info("[purge] removing events from event_auth_chain_links")
+ logger.info("[purge] removing from event_auth_chain_links")
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
@@ -399,7 +402,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"rejections",
"state_events",
):
- logger.info("[purge] removing %s from %s", room_id, table)
+ logger.info("[purge] removing from %s", table)
txn.execute(
"""
@@ -420,12 +423,14 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_push_actions",
"event_search",
"event_failed_pull_attempts",
+ # Note: the partial state tables have foreign keys between each other, and to
+ # `events` and `rooms`. We need to delete from them in the right order.
"partial_state_events",
+ "partial_state_rooms_servers",
+ "partial_state_rooms",
"events",
"federation_inbound_events_staging",
"local_current_membership",
- "partial_state_rooms_servers",
- "partial_state_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
@@ -452,7 +457,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# happy
"rooms",
):
- logger.info("[purge] removing %s from %s", room_id, table)
+ logger.info("[purge] removing from %s", table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
# Other tables we do NOT need to clear out:
@@ -484,6 +489,4 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# that already exist.
self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
- logger.info("[purge] done")
-
return state_groups
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9b2bbe060d..9f862f00c1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
IdGenerator,
StreamIdGenerator,
)
@@ -118,7 +117,7 @@ class PushRulesWorkerStore(
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ self._push_rules_stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"push_rules_stream",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..9a24f7a655 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -36,7 +36,6 @@ from synapse.storage.database import (
)
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict
@@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ self._pushers_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"pushers",
@@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT name FROM users
WHERE deactivated = ? and name > ?
@@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, access_token FROM pushers AS p
LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..074942b167 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
+ Sequence,
Tuple,
cast,
)
@@ -37,7 +39,7 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
- AbstractStreamIdTracker,
+ AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -63,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._receipts_id_gen: AbstractStreamIdTracker
+ self._receipts_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> Sequence[JsonDict]:
"""Get receipts for a single room for sending to clients.
Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[JsonDict]:
+ ) -> Sequence[JsonDict]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, Sequence[JsonDict]]:
if not room_ids:
return {}
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, JsonDict]:
+ ) -> Mapping[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -766,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)
- async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self._insert_linearized_receipt_txn,
@@ -885,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
def _populate_receipt_event_stream_ordering_txn(
txn: LoggingTransaction,
) -> bool:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
import logging
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import attr
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
- async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="user_delete_threepid",
)
- async def user_delete_threepids(self, user_id: str) -> None:
- """Delete all threepid this user has bound
-
- Args:
- user_id: The user id to delete all threepids of
-
- """
- await self.db_pool.simple_delete(
- "user_threepids",
- keyvalues={"user_id": user_id},
- desc="user_delete_threepids",
- )
-
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
) -> None:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..bc3a83919c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+ ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -397,141 +398,6 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None
@cached()
- async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
- )
- async def get_aggregation_groups_for_events(
- self, event_ids: Collection[str]
- ) -> Mapping[str, Optional[List[JsonDict]]]:
- """Get a list of annotations on the given events, grouped by event type and
- aggregation key, sorted by count.
-
- This is used e.g. to get the what and how many reactions have happend
- on an event.
-
- Args:
- event_ids: Fetch events that relate to these event IDs.
-
- Returns:
- A map of event IDs to a list of groups of annotations that match.
- Each entry is a dict with `type`, `key` and `count` fields.
- """
- # The number of entries to return per event ID.
- limit = 5
-
- clause, args = make_in_list_sql_clause(
- self.database_engine, "relates_to_id", event_ids
- )
- args.append(RelationTypes.ANNOTATION)
-
- sql = f"""
- SELECT
- relates_to_id,
- annotation.type,
- aggregation_key,
- COUNT(DISTINCT annotation.sender)
- FROM events AS annotation
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS parent ON
- parent.event_id = relates_to_id
- AND parent.room_id = annotation.room_id
- WHERE
- {clause}
- AND relation_type = ?
- GROUP BY relates_to_id, annotation.type, aggregation_key
- ORDER BY relates_to_id, COUNT(*) DESC
- """
-
- def _get_aggregation_groups_for_events_txn(
- txn: LoggingTransaction,
- ) -> Mapping[str, List[JsonDict]]:
- txn.execute(sql, args)
-
- result: Dict[str, List[JsonDict]] = {}
- for event_id, type, key, count in cast(
- List[Tuple[str, str, str, int]], txn
- ):
- event_results = result.setdefault(event_id, [])
-
- # Limit the number of results per event ID.
- if len(event_results) == limit:
- continue
-
- event_results.append({"type": type, "key": key, "count": count})
-
- return result
-
- return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
- )
-
- async def get_aggregation_groups_for_users(
- self, event_ids: Collection[str], users: FrozenSet[str]
- ) -> Dict[str, Dict[Tuple[str, str], int]]:
- """Fetch the partial aggregations for an event for specific users.
-
- This is used, in conjunction with get_aggregation_groups_for_event, to
- remove information from the results for ignored users.
-
- Args:
- event_ids: Fetch events that relate to these event IDs.
- users: The users to fetch information for.
-
- Returns:
- A map of event ID to a map of (event type, aggregation key) to a
- count of users.
- """
-
- if not users:
- return {}
-
- events_sql, args = make_in_list_sql_clause(
- self.database_engine, "relates_to_id", event_ids
- )
-
- users_sql, users_args = make_in_list_sql_clause(
- self.database_engine, "annotation.sender", users
- )
- args.extend(users_args)
- args.append(RelationTypes.ANNOTATION)
-
- sql = f"""
- SELECT
- relates_to_id,
- annotation.type,
- aggregation_key,
- COUNT(DISTINCT annotation.sender)
- FROM events AS annotation
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS parent ON
- parent.event_id = relates_to_id
- AND parent.room_id = annotation.room_id
- WHERE {events_sql} AND {users_sql} AND relation_type = ?
- GROUP BY relates_to_id, annotation.type, aggregation_key
- ORDER BY relates_to_id, COUNT(*) DESC
- """
-
- def _get_aggregation_groups_for_users_txn(
- txn: LoggingTransaction,
- ) -> Dict[str, Dict[Tuple[str, str], int]]:
- txn.execute(sql, args)
-
- result: Dict[str, Dict[Tuple[str, str], int]] = {}
- for event_id, type, key, count in cast(
- List[Tuple[str, str, str, int]], txn
- ):
- result.setdefault(event_id, {})[(type, key)] = count
-
- return result
-
- return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
- )
-
- @cached()
async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..3825bd6079 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_un_partial_stated_rooms_from_stream_txn,
)
+ async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+ """Retrieve an event report
+
+ Args:
+ report_id: ID of the event report in the database
+ Returns:
+ JSON dict of information from an event report or None if the
+ report does not exist.
+ """
+
+ def _get_event_report_txn(
+ txn: LoggingTransaction, report_id: int
+ ) -> Optional[Dict[str, Any]]:
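+ # Fetch the report along with the reported event's sender and JSON, plus the
+ # room's canonical alias and name.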
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.content,
+ events.sender,
+ room_stats_state.canonical_alias,
+ room_stats_state.name,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ JOIN room_stats_state
+ ON room_stats_state.room_id = er.room_id
+ WHERE er.id = ?
+ """
+
+ txn.execute(sql, [report_id])
+ row = txn.fetchone()
+
+ if not row:
+ return None
+
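+ # The `content` column holds the report body as JSON; pull out the optional score and reason.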
+ event_report = {
+ "id": row[0],
+ "received_ts": row[1],
+ "room_id": row[2],
+ "event_id": row[3],
+ "user_id": row[4],
+ "score": db_to_json(row[5]).get("score"),
+ "reason": db_to_json(row[5]).get("reason"),
+ "sender": row[6],
+ "canonical_alias": row[7],
+ "name": row[8],
+ "event_json": db_to_json(row[9]),
+ }
+
+ return event_report
+
+ return await self.db_pool.runInteraction(
+ "get_event_report", _get_event_report_txn, report_id
+ )
+
+ async def get_event_reports_paginate(
+ self,
+ start: int,
+ limit: int,
+ direction: Direction = Direction.BACKWARDS,
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Retrieve a paginated list of event reports
+
+ Args:
+ start: event offset to begin the query from
+ limit: number of rows to retrieve
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards)
+ user_id: filter reports by this user ID (substring match). Ignored if None
+ room_id: filter reports by this room ID (substring match). Ignored if None
+ Returns:
+ Tuple of:
+ json list of event reports
+ total number of event reports matching the filter criteria
+ """
+
+ def _get_event_reports_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ filters = []
+ args: List[object] = []
+
+ if user_id:
+ filters.append("er.user_id LIKE ?")
+ args.extend(["%" + user_id + "%"])
+ if room_id:
+ filters.append("er.room_id LIKE ?")
+ args.extend(["%" + room_id + "%"])
+
+ if direction == Direction.BACKWARDS:
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
+ sql = """
+ SELECT COUNT(*) as total_event_reports
+ FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
+ {}
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ count = cast(Tuple[int], txn.fetchone())[0]
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.content,
+ events.sender,
+ room_stats_state.canonical_alias,
+ room_stats_state.name
+ FROM event_reports AS er
+ LEFT JOIN events
+ ON events.event_id = er.event_id
+ JOIN room_stats_state
+ ON room_stats_state.room_id = er.room_id
+ {where_clause}
+ ORDER BY er.received_ts {order}
+ LIMIT ?
+ OFFSET ?
+ """.format(
+ where_clause=where_clause,
+ order=order,
+ )
+
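+ # The LIMIT/OFFSET placeholders come after the filter placeholders in the query.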
+ args += [limit, start]
+ txn.execute(sql, args)
+
+ event_reports = []
+ for row in txn:
+ try:
+ s = db_to_json(row[5]).get("score")
+ r = db_to_json(row[5]).get("reason")
+ except Exception:
+ logger.error("Unable to parse json from event_reports: %s", row[0])
+ continue
+ event_reports.append(
+ {
+ "id": row[0],
+ "received_ts": row[1],
+ "room_id": row[2],
+ "event_id": row[3],
+ "user_id": row[4],
+ "score": s,
+ "reason": r,
+ "sender": row[6],
+ "canonical_alias": row[7],
+ "name": row[8],
+ }
+ )
+
+ return event_reports, count
+
+ return await self.db_pool.runInteraction(
+ "get_event_reports_paginate", _get_event_reports_paginate_txn
+ )
+
+ async def delete_event_report(self, report_id: int) -> bool:
+ """Remove an event report from database.
+
+ Args:
+ report_id: Report to delete
+
+ Returns:
+ Whether the report was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ table="event_reports",
+ keyvalues={"id": report_id},
+ desc="delete_event_report",
+ )
+ except StoreError:
+ # Deletion failed because report does not exist
+ return False
+
+ return True
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
reason: Optional[str],
content: JsonDict,
received_ts: int,
- ) -> None:
+ ) -> int:
+ """Add an event report
+
+ Args:
+ room_id: Room that contains the reported event.
+ event_id: The reported event.
+ user_id: User who reports the event.
+ reason: Description that the user specifies.
+ content: Report request body (score and reason).
+ received_ts: Time when the user submitted the report (milliseconds).
+ Returns:
+ The ID of the newly created event report.
+ """
next_id = self._event_reports_id_gen.get_next()
await self.db_pool.simple_insert(
table="event_reports",
@@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
},
desc="add_event_report",
)
-
- async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
- """Retrieve an event report
-
- Args:
- report_id: ID of reported event in database
- Returns:
- JSON dict of information from an event report or None if the
- report does not exist.
- """
-
- def _get_event_report_txn(
- txn: LoggingTransaction, report_id: int
- ) -> Optional[Dict[str, Any]]:
-
- sql = """
- SELECT
- er.id,
- er.received_ts,
- er.room_id,
- er.event_id,
- er.user_id,
- er.content,
- events.sender,
- room_stats_state.canonical_alias,
- room_stats_state.name,
- event_json.json AS event_json
- FROM event_reports AS er
- LEFT JOIN events
- ON events.event_id = er.event_id
- JOIN event_json
- ON event_json.event_id = er.event_id
- JOIN room_stats_state
- ON room_stats_state.room_id = er.room_id
- WHERE er.id = ?
- """
-
- txn.execute(sql, [report_id])
- row = txn.fetchone()
-
- if not row:
- return None
-
- event_report = {
- "id": row[0],
- "received_ts": row[1],
- "room_id": row[2],
- "event_id": row[3],
- "user_id": row[4],
- "score": db_to_json(row[5]).get("score"),
- "reason": db_to_json(row[5]).get("reason"),
- "sender": row[6],
- "canonical_alias": row[7],
- "name": row[8],
- "event_json": db_to_json(row[9]),
- }
-
- return event_report
-
- return await self.db_pool.runInteraction(
- "get_event_report", _get_event_report_txn, report_id
- )
-
- async def get_event_reports_paginate(
- self,
- start: int,
- limit: int,
- direction: Direction = Direction.BACKWARDS,
- user_id: Optional[str] = None,
- room_id: Optional[str] = None,
- ) -> Tuple[List[Dict[str, Any]], int]:
- """Retrieve a paginated list of event reports
-
- Args:
- start: event offset to begin the query from
- limit: number of rows to retrieve
- direction: Whether to fetch the most recent first (backwards) or the
- oldest first (forwards)
- user_id: search for user_id. Ignored if user_id is None
- room_id: search for room_id. Ignored if room_id is None
- Returns:
- Tuple of:
- json list of event reports
- total number of event reports matching the filter criteria
- """
-
- def _get_event_reports_paginate_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], int]:
- filters = []
- args: List[object] = []
-
- if user_id:
- filters.append("er.user_id LIKE ?")
- args.extend(["%" + user_id + "%"])
- if room_id:
- filters.append("er.room_id LIKE ?")
- args.extend(["%" + room_id + "%"])
-
- if direction == Direction.BACKWARDS:
- order = "DESC"
- else:
- order = "ASC"
-
- where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
-
- # We join on room_stats_state despite not using any columns from it
- # because the join can influence the number of rows returned;
- # e.g. a room that doesn't have state, maybe because it was deleted.
- # The query returning the total count should be consistent with
- # the query returning the results.
- sql = """
- SELECT COUNT(*) as total_event_reports
- FROM event_reports AS er
- JOIN room_stats_state ON room_stats_state.room_id = er.room_id
- {}
- """.format(
- where_clause
- )
- txn.execute(sql, args)
- count = cast(Tuple[int], txn.fetchone())[0]
-
- sql = """
- SELECT
- er.id,
- er.received_ts,
- er.room_id,
- er.event_id,
- er.user_id,
- er.content,
- events.sender,
- room_stats_state.canonical_alias,
- room_stats_state.name
- FROM event_reports AS er
- LEFT JOIN events
- ON events.event_id = er.event_id
- JOIN room_stats_state
- ON room_stats_state.room_id = er.room_id
- {where_clause}
- ORDER BY er.received_ts {order}
- LIMIT ?
- OFFSET ?
- """.format(
- where_clause=where_clause,
- order=order,
- )
-
- args += [limit, start]
- txn.execute(sql, args)
-
- event_reports = []
- for row in txn:
- try:
- s = db_to_json(row[5]).get("score")
- r = db_to_json(row[5]).get("reason")
- except Exception:
- logger.error("Unable to parse json from event_reports: %s", row[0])
- continue
- event_reports.append(
- {
- "id": row[0],
- "received_ts": row[1],
- "room_id": row[2],
- "event_id": row[3],
- "user_id": row[4],
- "score": s,
- "reason": r,
- "sender": row[6],
- "canonical_alias": row[7],
- "name": row[8],
- }
- )
-
- return event_reports, count
-
- return await self.db_pool.runInteraction(
- "get_event_reports_paginate", _get_event_reports_paginate_txn
- )
+ return next_id
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self._known_servers_count
@cached(max_entries=100000, iterable=True)
- async def get_users_in_room(self, room_id: str) -> List[str]:
+ async def get_users_in_room(self, room_id: str) -> Sequence[str]:
"""Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
- def get_user_in_room_with_profile(
- self, room_id: str, user_id: str
- ) -> Dict[str, ProfileInfo]:
+ def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
raise NotImplementedError()
@cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room.
The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000)
- async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+ async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(iterable=True)
- async def get_local_users_in_room(self, room_id: str) -> List[str]:
+ async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
"""
Retrieves a list of the current room members who are local to the server.
"""
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id)
- user_who_share_room = set()
+ user_who_share_room: Set[str] = set()
for room_id in room_ids:
user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
class SearchBackgroundUpdateStore(SearchWorkerStore):
-
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
-
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
from unpaddedbase64 import encode_base64
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
- def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+ def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
)
async def get_event_reference_hashes(
self, event_ids: Collection[str]
- ) -> Dict[str, Dict[str, bytes]]:
+ ) -> Mapping[str, Mapping[str, bytes]]:
"""Get all hashes for given events.
Args:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
insert_cols = []
qargs = []
- for (key, val) in chain(
+ for key, val in chain(
keyvalues.items(), absolutes.items(), additive_relatives.items()
):
insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
- async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+ async def get_tags_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for a user.
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
- ) -> Dict[str, Dict[str, JsonDict]]:
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
-
if direction == Direction.BACKWARDS:
order = "DESC"
else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,11 +14,12 @@
import logging
import re
+import unicodedata
from typing import (
TYPE_CHECKING,
- Dict,
Iterable,
List,
+ Mapping,
Optional,
Sequence,
Set,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:
-
# Get all the rooms that we want to process.
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
values={"display_name": display_name, "avatar_url": avatar_url},
)
+ # The display name that goes into the database index.
+ index_display_name = display_name
+ if index_display_name is not None:
+ index_display_name = _filter_text_for_index(index_display_name)
+
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id,
get_localpart_from_id(user_id),
get_domain_from_id(user_id),
- display_name,
+ index_display_name,
),
)
elif isinstance(self.database_engine, Sqlite3Engine):
- value = "%s %s" % (user_id, display_name) if display_name else user_id
+ value = (
+ "%s %s" % (user_id, index_display_name)
+ if index_display_name
+ else user_id
+ )
self.db_pool.simple_upsert_txn(
txn,
table="user_directory_search",
@@ -586,7 +595,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+ async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results[0:limit]}
+def _filter_text_for_index(text: str) -> str:
+ """Transforms text before it is inserted into the user directory index, or searched
+ for in the user directory index.
+
+ Note that the user directory search table needs to be rebuilt whenever this function
+ changes.
+ """
+ # Lowercase the text, to make searches case-insensitive.
+ # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+ # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+ # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+ # all.
+ text = text.lower()
+
+ # Normalize the text. NFKC normalization has two effects:
+ # 1. It canonicalizes the text, i.e. maps all visually identical strings to the same
+ # string. For example, ["e", "◌́"] is mapped to ["é"].
+ # 2. It maps strings that are roughly equivalent to the same string.
+ # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+ # ["i", "9"].
+ text = unicodedata.normalize("NFKC", text)
+
+ # Note that nothing is done to make searches accent-insensitive.
+ # That could be achieved by converting to NFKD form instead (with combining accents
+ # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+ # The downside of this may be noisier search results, since search terms with
+ # explicit accents will match characters with no accents, or completely different
+ # accents.
+ #
+ # text = unicodedata.normalize("NFKD", text)
+ # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+ return text
+
+
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
We specifically add both a prefix and non prefix matching term so that
exact matches get ranked higher.
"""
+ search_term = _filter_text_for_index(search_term)
# Pull out the individual words, discarding any non-word characters.
results = _parse_words(search_term)
@@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
- results = _parse_words(search_term)
+ search_term = _filter_text_for_index(search_term)
+
+ escaped_words = []
+ for word in _parse_words(search_term):
+ # Postgres tsvector and tsquery quoting rules:
+ # words potentially containing punctuation should be quoted
+ # and then existing quotes and backslashes should be doubled
+ # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
- both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
- exact = " & ".join("%s" % (result,) for result in results)
- prefix = " & ".join("%s:*" % (result,) for result in results)
+ quoted_word = word.replace("'", "''").replace("\\", "\\\\")
+ escaped_words.append(f"'{quoted_word}'")
+
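+ # Build the combined (prefix-or-exact per word), exact-only and prefix-only tsquery strings from the escaped words.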
+ both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words)
+ exact = " & ".join("%s" % (word,) for word in escaped_words)
+ prefix = " & ".join("%s:*" % (word,) for word in escaped_words)
return both, exact, prefix
@@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]:
if USE_ICU:
return _parse_words_with_icu(search_term)
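+ # ICU isn't available: fall back to a simple regex-based split.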
+ return _parse_words_with_regex(search_term)
+
+
+def _parse_words_with_regex(search_term: str) -> List[str]:
+ """
+ Break down the search term into words when ICU is not available.
+ See: `_parse_words`
+ """
return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d743282f13..097dea5182 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
-
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..29ff64e876 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
- (
- non_member_state,
- incomplete_groups_nm,
- ) = self._get_state_for_groups_using_cache(
+ non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
- (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
+ member_state, incomplete_groups_m = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
@@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
+ async def store_state_deltas_for_batched(
+ self,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+ room_id: str,
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state deltas for a group of events and contexts created to be
+ batch persisted. Note that all the events must be in a linear chain (i.e. a <- b <- c).
+
+ Args:
+ events_and_context: the events to generate and store state groups for,
+ and their associated contexts
+ room_id: the id of the room the events were created for
+ prev_group: the state group of the last event persisted before the batched events
+ were created
+ """
+
+ def insert_deltas_group_txn(
+ txn: LoggingTransaction,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state groups for the provided events and contexts.
+
+ Requires that we have the state as a delta from the last persisted state group.
+
+ Returns:
+ The given list of events and contexts, with the new state groups assigned to each context.
+ """
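+ # Sanity check: the group we are building deltas on top of must already be persisted.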
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
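+ # Only state events get a new state group; allocate one fresh group ID per state event.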
+ num_state_groups = sum(
+ 1 for event, _ in events_and_context if event.is_state()
+ )
+
+ state_groups = self._state_group_seq_gen.get_next_mult_txn(
+ txn, num_state_groups
+ )
+
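+ # Walk the linear chain of events, threading the state group forward: each
+ # state event gets the next fresh group (whose delta is just that event),
+ # while non-state events simply reuse the group of the event before them.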
+ sg_before = prev_group
+ state_group_iter = iter(state_groups)
+ for event, context in events_and_context:
+ if not event.is_state():
+ context.state_group_after_event = sg_before
+ context.state_group_before_event = sg_before
+ continue
+
+ sg_after = next(state_group_iter)
+ context.state_group_after_event = sg_after
+ context.state_group_before_event = sg_before
+ context.state_delta_due_to_event = {
+ (event.type, event.state_key): event.event_id
+ }
+ sg_before = sg_after
+
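+ # Create the `state_groups` rows for the newly allocated groups.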
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups",
+ keys=("id", "room_id", "event_id"),
+ values=[
+ (context.state_group_after_event, room_id, event.event_id)
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
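+ # Record each new group as an edge on top of the group it was built from.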
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_group_edges",
+ keys=("state_group", "prev_state_group"),
+ values=[
+ (
+ context.state_group_after_event,
+ context.state_group_before_event,
+ )
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
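+ # Store the actual per-group delta: the single (type, state_key) -> event_id entry contributed by the event itself.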
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (
+ context.state_group_after_event,
+ room_id,
+ key[0],
+ key[1],
+ state_id,
+ )
+ for event, context in events_and_context
+ if context.state_delta_due_to_event is not None
+ for key, state_id in context.state_delta_due_to_event.items()
+ ],
+ )
+ return events_and_context
+
+ return await self.db_pool.runInteraction(
+ "store_state_deltas_for_batched.insert_deltas_group",
+ insert_deltas_group_txn,
+ events_and_context,
+ prev_group,
+ )
+
async def store_state_group(
self,
event_id: str,
@@ -689,12 +805,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete: State groups to delete
"""
+ logger.info("[purge] Starting state purge")
await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
state_groups_to_delete,
)
+ logger.info("[purge] Done with state purge")
def _purge_room_state_txn(
self,
|