diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
- FilteringStore,
+ FilteringWorkerStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 95567826f2..308d19440f 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -237,6 +237,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
return None
+ async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+ self, user_id: str, data_type: str
+ ) -> Optional[int]:
+ """
+ Returns:
+ The stream ID of the account data,
+ or None if there is no such account data.
+ """
+
+ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
+ sql = """
+ SELECT stream_id FROM account_data
+ WHERE user_id = ? AND account_data_type = ?
+ ORDER BY stream_id DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (user_id, data_type))
+
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ else:
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_latest_stream_id_for_global_account_data_by_type_for_user",
+ get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+ )
+
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- for (user_id, messages_by_device) in edu["messages"].items():
- for (device_id, msg) in messages_by_device.items():
+ for user_id, messages_by_device in edu["messages"].items():
+ for device_id, msg in messages_by_device.items():
with start_active_span("store_outgoing_to_device_message"):
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
) -> Tuple[int, bool]:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1ca66d57d4..0dd15f16ff 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -512,7 +512,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
results.append(("org.matrix.signing_key_update", result))
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- for (user_id, edu) in results:
+ for user_id, edu in results:
issue_8631_logger.debug(
"device update to %s for %s from %s to %s: %s",
destination,
@@ -1316,7 +1316,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
"""
count = 0
- for (destination, user_id, stream_id, device_id) in rows:
+ for destination, user_id, stream_id, device_id in rows:
txn.execute(
delete_sql, (destination, user_id, stream_id, stream_id, device_id)
)
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No backup with that version exists")
values = []
- for (room_id, session_id, room_key) in room_keys:
+ for room_id, session_id, room_key in room_keys:
values.append(
(
user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 2c2d145666..b9c39b1718 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# add each cross-signing signature to the correct device in the result dict.
- for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ for user_id, key_id, device_id, signature in cross_sigs_result:
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
@@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
# devices.
user_list = []
user_device_list = []
- for (user_id, device_id) in query_list:
+ for user_id, device_id in query_list:
if device_id is None:
user_list.append(user_id)
else:
@@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
txn.execute(sql, query_params)
- for (user_id, device_id, display_name, key_json) in txn:
+ for user_id, device_id, display_name, key_json in txn:
assert device_id is not None
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
@@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
signature_query_clauses = []
signature_query_params = []
- for (user_id, device_id) in device_query:
+ for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca780cca36..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
latest_events: List[str],
limit: int,
) -> List[str]:
-
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
event_results: List[str] = []
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7996cbb557..73b8aea16c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -469,7 +469,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events: List[EventBase],
) -> None:
-
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
return
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 584536111d..0a275e6ce6 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows = 0
last_row_event_id = ""
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
except Exception as e:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d0ef10258..b7e7498125 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1493,7 +1493,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(redactions_sql + clause, args)
- for (redacter, redacted) in txn:
+ for redacter, redacted in txn:
d = event_dict.get(redacted)
if d:
d.redactions.append(redacter)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
return db_to_json(def_json)
-
-class FilteringStore(FilteringWorkerStore):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
return filter_id
- return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ attempts = 0
+ while True:
+ # Try a few times.
+ # This is technically needed if a user tries to create two filters at once,
+ # leading to two concurrent transactions.
+ # The failure case would be:
+ # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+ # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+ # - INSERT INTO ... → both transactions insert filter_id = 6
+ # One of the transactions will commit. The other will get a unique key
+ # constraint violation error (IntegrityError). This is not the same as a
+ # serialisability violation, which would be automatically retried by
+ # `runInteraction`.
+ try:
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+
+ if attempts >= 5:
+ raise StoreError(500, "Couldn't generate a filter ID.")
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
-
# Set ordering
order_by_column = MediaSortOrder(order_by).value
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..fddbc07afa 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,6 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT name FROM users
WHERE deactivated = ? and name > ?
@@ -392,7 +391,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, access_token FROM pushers AS p
LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +447,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index dddf49c2d5..92a82240ab 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
def _populate_receipt_event_stream_ordering_txn(
txn: LoggingTransaction,
) -> bool:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a55e17624..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="user_delete_threepid",
)
- async def user_delete_threepids(self, user_id: str) -> None:
- """Delete all threepid this user has bound
-
- Args:
- user_id: The user id to delete all threepids of
-
- """
- await self.db_pool.simple_delete(
- "user_threepids",
- keyvalues={"user_id": user_id},
- desc="user_delete_threepids",
- )
-
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
) -> None:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..a2e9519cb6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_un_partial_stated_rooms_from_stream_txn,
)
+ async def delete_event_report(self, report_id: int) -> bool:
+ """Remove an event report from the database.
+
+ Args:
+ report_id: Report to delete
+
+ Returns:
+ Whether the report was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ table="event_reports",
+ keyvalues={"id": report_id},
+ desc="delete_event_report",
+ )
+ except StoreError:
+ # Deletion failed because report does not exist
+ return False
+
+ return True
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2160,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
reason: Optional[str],
content: JsonDict,
received_ts: int,
- ) -> None:
+ ) -> int:
+ """Add an event report
+
+ Args:
+ room_id: Room that contains the reported event.
+ event_id: The reported event.
+ user_id: User who reports the event.
+ reason: Description that the user specifies.
+ content: Report request body (score and reason).
+ received_ts: Time when the user submitted the report (milliseconds).
+ Returns:
+ Id of the event report.
+ """
next_id = self._event_reports_id_gen.get_next()
await self.db_pool.simple_insert(
table="event_reports",
@@ -2154,6 +2187,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
},
desc="add_event_report",
)
+ return next_id
async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
"""Retrieve an event report
@@ -2168,7 +2202,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def _get_event_report_txn(
txn: LoggingTransaction, report_id: int
) -> Optional[Dict[str, Any]]:
-
sql = """
SELECT
er.id,
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
class SearchBackgroundUpdateStore(SearchWorkerStore):
-
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
-
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
insert_cols = []
qargs = []
- for (key, val) in chain(
+ for key, val in chain(
keyvalues.items(), absolutes.items(), additive_relatives.items()
):
insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
-
if direction == Direction.BACKWARDS:
order = "DESC"
else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 30af4b3b6c..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,6 +14,7 @@
import logging
import re
+import unicodedata
from typing import (
TYPE_CHECKING,
Iterable,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:
-
# Get all the rooms that we want to process.
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
values={"display_name": display_name, "avatar_url": avatar_url},
)
+ # The display name that goes into the database index.
+ index_display_name = display_name
+ if index_display_name is not None:
+ index_display_name = _filter_text_for_index(index_display_name)
+
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id,
get_localpart_from_id(user_id),
get_domain_from_id(user_id),
- display_name,
+ index_display_name,
),
)
elif isinstance(self.database_engine, Sqlite3Engine):
- value = "%s %s" % (user_id, display_name) if display_name else user_id
+ value = (
+ "%s %s" % (user_id, index_display_name)
+ if index_display_name
+ else user_id
+ )
self.db_pool.simple_upsert_txn(
txn,
table="user_directory_search",
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results[0:limit]}
+def _filter_text_for_index(text: str) -> str:
+ """Transforms text before it is inserted into the user directory index, or searched
+ for in the user directory index.
+
+ Note that the user directory search table needs to be rebuilt whenever this function
+ changes.
+ """
+ # Lowercase the text, to make searches case-insensitive.
+ # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+ # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+ # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+ # all.
+ text = text.lower()
+
+ # Normalize the text. NFKC normalization has two effects:
+ # 1. It canonicalizes the text, ie. maps all visually identical strings to the same
+ # string. For example, ["e", "◌́"] is mapped to ["é"].
+ # 2. It maps strings that are roughly equivalent to the same string.
+ # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+ # ["i", "9"].
+ text = unicodedata.normalize("NFKC", text)
+
+ # Note that nothing is done to make searches accent-insensitive.
+ # That could be achieved by converting to NFKD form instead (with combining accents
+ # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+ # The downside of this may be noisier search results, since search terms with
+ # explicit accents will match characters with no accents, or completely different
+ # accents.
+ #
+ # text = unicodedata.normalize("NFKD", text)
+ # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+ return text
+
+
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
We specifically add both a prefix and non prefix matching term so that
exact matches get ranked higher.
"""
+ search_term = _filter_text_for_index(search_term)
# Pull out the individual words, discarding any non-word characters.
results = _parse_words(search_term)
@@ -918,6 +963,8 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
+ search_term = _filter_text_for_index(search_term)
+
escaped_words = []
for word in _parse_words(search_term):
# Postgres tsvector and tsquery quoting rules:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d743282f13..097dea5182 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
-
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..bf4cdfdf29 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
- (
- non_member_state,
- incomplete_groups_nm,
- ) = self._get_state_for_groups_using_cache(
+ non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
- (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
+ member_state, incomplete_groups_m = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
@@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
+ async def store_state_deltas_for_batched(
+ self,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+ room_id: str,
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state deltas for a group of events and contexts created to be
+ batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
+
+ Args:
+ events_and_context: the events to generate and store state groups for
+ and their associated contexts
+ room_id: the id of the room the events were created for
+ prev_group: the state group of the last event persisted before the batched events
+ were created
+ """
+
+ def insert_deltas_group_txn(
+ txn: LoggingTransaction,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state groups for the provided events and contexts.
+
+ Requires that we have the state as a delta from the last persisted state group.
+
+ Returns:
+ The given events and contexts, with each context's state groups filled in
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ num_state_groups = sum(
+ 1 for event, _ in events_and_context if event.is_state()
+ )
+
+ state_groups = self._state_group_seq_gen.get_next_mult_txn(
+ txn, num_state_groups
+ )
+
+ sg_before = prev_group
+ state_group_iter = iter(state_groups)
+ for event, context in events_and_context:
+ if not event.is_state():
+ context.state_group_after_event = sg_before
+ context.state_group_before_event = sg_before
+ continue
+
+ sg_after = next(state_group_iter)
+ context.state_group_after_event = sg_after
+ context.state_group_before_event = sg_before
+ context.state_delta_due_to_event = {
+ (event.type, event.state_key): event.event_id
+ }
+ sg_before = sg_after
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups",
+ keys=("id", "room_id", "event_id"),
+ values=[
+ (context.state_group_after_event, room_id, event.event_id)
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_group_edges",
+ keys=("state_group", "prev_state_group"),
+ values=[
+ (
+ context.state_group_after_event,
+ context.state_group_before_event,
+ )
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (
+ context.state_group_after_event,
+ room_id,
+ key[0],
+ key[1],
+ state_id,
+ )
+ for event, context in events_and_context
+ if context.state_delta_due_to_event is not None
+ for key, state_id in context.state_delta_due_to_event.items()
+ ],
+ )
+ return events_and_context
+
+ return await self.db_pool.runInteraction(
+ "store_state_deltas_for_batched.insert_deltas_group",
+ insert_deltas_group_txn,
+ events_and_context,
+ prev_group,
+ )
+
async def store_state_group(
self,
event_id: str,
|