diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
- FilteringStore,
+ FilteringWorkerStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 95567826f2..a9843f6e17 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
):
super().__init__(database, db_conn, hs)
- # `_can_write_to_account_data` indicates whether the current worker is allowed
- # to write account data. A value of `True` implies that `_account_data_id_gen`
- # is an `AbstractStreamIdGenerator` and not just a tracker.
- self._account_data_id_gen: AbstractStreamIdTracker
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
+ self._account_data_id_gen: AbstractStreamIdGenerator
+
if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -237,6 +234,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
return None
+ async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+ self, user_id: str, data_type: str
+ ) -> Optional[int]:
+        """
+        Get the latest stream ID for a user's global account data of the given type.
+
+        Returns:
+            The stream ID of the account data,
+            or None if there is no such account data.
+        """
+
+ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
+ sql = """
+ SELECT stream_id FROM account_data
+ WHERE user_id = ? AND account_data_type = ?
+ ORDER BY stream_id DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (user_id, data_type))
+
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ else:
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_latest_stream_id_for_global_account_data_by_type_for_user",
+ get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+ )
+
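One plausible consumer of the new lookup (the helper below is purely illustrative and not part of this change) is code that needs to know whether a user's global account data of a given type has been written since a known stream position:

```python
async def account_data_changed_since(
    store: "AccountDataWorkerStore", user_id: str, data_type: str, since_stream_id: int
) -> bool:
    """Illustrative helper: has this user's global account data of `data_type`
    been written after `since_stream_id`?"""
    latest = await store.get_latest_stream_id_for_global_account_data_by_type_for_user(
        user_id, data_type
    )
    return latest is not None and latest > since_stream_id
```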
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
@@ -527,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -554,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str
- ) -> Optional[int]:
+ ) -> int:
"""Delete the room account data for the user of a given type.
Args:
@@ -567,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
data to delete.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_room_txn(
txn: LoggingTransaction, next_id: int
@@ -606,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
next_id,
)
- if not row_updated:
- return None
-
- self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_room_account_data_for_user.invalidate((user_id,))
- self.get_account_data_for_room.invalidate((user_id, room_id))
- self.get_account_data_for_room_and_type.prefill(
- (user_id, room_id, account_data_type), {}
- )
+ if row_updated:
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_room_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type), {}
+ )
return self._account_data_id_gen.get_current_token()
@@ -632,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
@@ -722,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self,
user_id: str,
account_data_type: str,
- ) -> Optional[int]:
+ ) -> int:
"""
Delete a single piece of user account data by type.
@@ -739,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
to delete.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_user_txn(
txn: LoggingTransaction, next_id: int
@@ -809,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
next_id,
)
- if not row_updated:
- return None
-
- self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_global_account_data_for_user.invalidate((user_id,))
- self.get_global_account_data_by_type_for_user.prefill(
- (user_id, account_data_type), {}
- )
+ if row_updated:
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_global_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.prefill(
+ (user_id, account_data_type), {}
+ )
return self._account_data_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 5b66431691..096dec7f87 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
- self._attempt_to_invalidate_cache(
- "get_aggregation_groups_for_event", (relates_to,)
- )
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- for (user_id, messages_by_device) in edu["messages"].items():
- for (device_id, msg) in messages_by_device.items():
+ for user_id, messages_by_device in edu["messages"].items():
+ for device_id, msg in messages_by_device.items():
with start_active_span("store_outgoing_to_device_message"):
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
) -> Tuple[int, bool]:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1ca66d57d4..5503621ad6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -52,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
@@ -91,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ self._device_list_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_lists_stream",
@@ -512,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
results.append(("org.matrix.signing_key_update", result))
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- for (user_id, edu) in results:
+ for user_id, edu in results:
issue_8631_logger.debug(
"device update to %s for %s from %s to %s: %s",
destination,
@@ -712,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The new stream ID.
"""
- # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
- # than DeviceWorkerStore?
- async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -1316,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
"""
count = 0
- for (destination, user_id, stream_id, device_id) in rows:
+ for destination, user_id, stream_id, device_id in rows:
txn.execute(
delete_sql, (destination, user_id, stream_id, stream_id, device_id)
)
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No backup with that version exists")
values = []
- for (room_id, session_id, room_key) in room_keys:
+ for room_id, session_id, room_key in room_keys:
values.append(
(
user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 2c2d145666..b9c39b1718 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# add each cross-signing signature to the correct device in the result dict.
- for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ for user_id, key_id, device_id, signature in cross_sigs_result:
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
@@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
# devices.
user_list = []
user_device_list = []
- for (user_id, device_id) in query_list:
+ for user_id, device_id in query_list:
if device_id is None:
user_list.append(user_id)
else:
@@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
txn.execute(sql, query_params)
- for (user_id, device_id, display_name, key_json) in txn:
+ for user_id, device_id, display_name, key_json in txn:
assert device_id is not None
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
@@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
signature_query_clauses = []
signature_query_params = []
- for (user_id, device_id) in device_query:
+ for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca780cca36..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
latest_events: List[str],
limit: int,
) -> List[str]:
-
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
event_results: List[str] = []
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7996cbb557..a8a4ed4436 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -469,7 +469,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events: List[EventBase],
) -> None:
-
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
return
@@ -2025,10 +2024,6 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_relations_for_event, (redacted_relates_to,)
)
- if rel_type == RelationTypes.ANNOTATION:
- self.store._invalidate_cache_and_stream(
- txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
- )
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_references_for_event, (redacted_relates_to,)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 584536111d..daef3685b0 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows = 0
last_row_event_id = ""
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
except Exception as e:
@@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
- txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
- )
- self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d0ef10258..20b7a68362 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._stream_id_gen: AbstractStreamIdTracker
- self._backfill_id_gen: AbstractStreamIdTracker
+ self._stream_id_gen: AbstractStreamIdGenerator
+ self._backfill_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(redactions_sql + clause, args)
- for (redacter, redacted) in txn:
+ for redacter, redacted in txn:
d = event_dict.get(redacted)
if d:
d.redactions.append(redacter)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
return db_to_json(def_json)
-
-class FilteringStore(FilteringWorkerStore):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
return filter_id
- return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ attempts = 0
+ while True:
+ # Try a few times.
+ # This is technically needed if a user tries to create two filters at once,
+ # leading to two concurrent transactions.
+ # The failure case would be:
+ # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+ # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+ # - INSERT INTO ... → both transactions insert filter_id = 6
+ # One of the transactions will commit. The other will get a unique key
+ # constraint violation error (IntegrityError). This is not the same as a
+ # serialisability violation, which would be automatically retried by
+ # `runInteraction`.
+ try:
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+
+ if attempts >= 5:
+ raise StoreError(500, "Couldn't generate a filter ID.")
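For readers unfamiliar with the pattern, here is a minimal standalone sketch of the same retry-on-`IntegrityError` idea using plain `sqlite3` rather than Synapse's `DatabasePool`; it assumes a `user_filters` table with a `UNIQUE(user_id, filter_id)` constraint, mirroring the failure case described in the comment above:

```python
import sqlite3


def add_user_filter(conn: sqlite3.Connection, user_localpart: str, filter_json: str) -> int:
    """Allocate the next filter_id with SELECT MAX(...) + INSERT, retrying when a
    concurrent writer wins the race and trips the unique constraint."""
    attempts = 0
    while True:
        try:
            with conn:  # one transaction per attempt; commits on success
                row = conn.execute(
                    "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
                    (user_localpart,),
                ).fetchone()
                next_id = 0 if row[0] is None else row[0] + 1
                conn.execute(
                    "INSERT INTO user_filters (user_id, filter_id, filter_json)"
                    " VALUES (?, ?, ?)",
                    (user_localpart, next_id, filter_json),
                )
                return next_id
        except sqlite3.IntegrityError:
            # Another writer inserted the same filter_id first; try again.
            attempts += 1
            if attempts >= 5:
                raise
```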
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
-
# Set ordering
order_by_column = MediaSortOrder(order_by).value
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9b2bbe060d..9f862f00c1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
IdGenerator,
StreamIdGenerator,
)
@@ -118,7 +117,7 @@ class PushRulesWorkerStore(
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ self._push_rules_stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"push_rules_stream",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..9a24f7a655 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -36,7 +36,6 @@ from synapse.storage.database import (
)
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict
@@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ self._pushers_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"pushers",
@@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT name FROM users
WHERE deactivated = ? and name > ?
@@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, access_token FROM pushers AS p
LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index dddf49c2d5..074942b167 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -39,7 +39,7 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
- AbstractStreamIdTracker,
+ AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -65,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._receipts_id_gen: AbstractStreamIdTracker
+ self._receipts_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@@ -768,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)
- async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self._insert_linearized_receipt_txn,
@@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
def _populate_receipt_event_stream_ordering_txn(
txn: LoggingTransaction,
) -> bool:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a55e17624..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="user_delete_threepid",
)
- async def user_delete_threepids(self, user_id: str) -> None:
- """Delete all threepid this user has bound
-
- Args:
- user_id: The user id to delete all threepids of
-
- """
- await self.db_pool.simple_delete(
- "user_threepids",
- keyvalues={"user_id": user_id},
- desc="user_delete_threepids",
- )
-
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
) -> None:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fa3266c081..bc3a83919c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -398,143 +398,6 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None
@cached()
- async def get_aggregation_groups_for_event(
- self, event_id: str
- ) -> Sequence[JsonDict]:
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
- )
- async def get_aggregation_groups_for_events(
- self, event_ids: Collection[str]
- ) -> Mapping[str, Optional[List[JsonDict]]]:
- """Get a list of annotations on the given events, grouped by event type and
- aggregation key, sorted by count.
-
- This is used e.g. to get the what and how many reactions have happend
- on an event.
-
- Args:
- event_ids: Fetch events that relate to these event IDs.
-
- Returns:
- A map of event IDs to a list of groups of annotations that match.
- Each entry is a dict with `type`, `key` and `count` fields.
- """
- # The number of entries to return per event ID.
- limit = 5
-
- clause, args = make_in_list_sql_clause(
- self.database_engine, "relates_to_id", event_ids
- )
- args.append(RelationTypes.ANNOTATION)
-
- sql = f"""
- SELECT
- relates_to_id,
- annotation.type,
- aggregation_key,
- COUNT(DISTINCT annotation.sender)
- FROM events AS annotation
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS parent ON
- parent.event_id = relates_to_id
- AND parent.room_id = annotation.room_id
- WHERE
- {clause}
- AND relation_type = ?
- GROUP BY relates_to_id, annotation.type, aggregation_key
- ORDER BY relates_to_id, COUNT(*) DESC
- """
-
- def _get_aggregation_groups_for_events_txn(
- txn: LoggingTransaction,
- ) -> Mapping[str, List[JsonDict]]:
- txn.execute(sql, args)
-
- result: Dict[str, List[JsonDict]] = {}
- for event_id, type, key, count in cast(
- List[Tuple[str, str, str, int]], txn
- ):
- event_results = result.setdefault(event_id, [])
-
- # Limit the number of results per event ID.
- if len(event_results) == limit:
- continue
-
- event_results.append({"type": type, "key": key, "count": count})
-
- return result
-
- return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
- )
-
- async def get_aggregation_groups_for_users(
- self, event_ids: Collection[str], users: FrozenSet[str]
- ) -> Dict[str, Dict[Tuple[str, str], int]]:
- """Fetch the partial aggregations for an event for specific users.
-
- This is used, in conjunction with get_aggregation_groups_for_event, to
- remove information from the results for ignored users.
-
- Args:
- event_ids: Fetch events that relate to these event IDs.
- users: The users to fetch information for.
-
- Returns:
- A map of event ID to a map of (event type, aggregation key) to a
- count of users.
- """
-
- if not users:
- return {}
-
- events_sql, args = make_in_list_sql_clause(
- self.database_engine, "relates_to_id", event_ids
- )
-
- users_sql, users_args = make_in_list_sql_clause(
- self.database_engine, "annotation.sender", users
- )
- args.extend(users_args)
- args.append(RelationTypes.ANNOTATION)
-
- sql = f"""
- SELECT
- relates_to_id,
- annotation.type,
- aggregation_key,
- COUNT(DISTINCT annotation.sender)
- FROM events AS annotation
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS parent ON
- parent.event_id = relates_to_id
- AND parent.room_id = annotation.room_id
- WHERE {events_sql} AND {users_sql} AND relation_type = ?
- GROUP BY relates_to_id, annotation.type, aggregation_key
- ORDER BY relates_to_id, COUNT(*) DESC
- """
-
- def _get_aggregation_groups_for_users_txn(
- txn: LoggingTransaction,
- ) -> Dict[str, Dict[Tuple[str, str], int]]:
- txn.execute(sql, args)
-
- result: Dict[str, Dict[Tuple[str, str], int]] = {}
- for event_id, type, key, count in cast(
- List[Tuple[str, str, str, int]], txn
- ):
- result.setdefault(event_id, {})[(type, key)] = count
-
- return result
-
- return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
- )
-
- @cached()
async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..3825bd6079 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_un_partial_stated_rooms_from_stream_txn,
)
+ async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+ """Retrieve an event report
+
+ Args:
+ report_id: ID of reported event in database
+ Returns:
+ JSON dict of information from an event report or None if the
+ report does not exist.
+ """
+
+ def _get_event_report_txn(
+ txn: LoggingTransaction, report_id: int
+ ) -> Optional[Dict[str, Any]]:
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.content,
+ events.sender,
+ room_stats_state.canonical_alias,
+ room_stats_state.name,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ JOIN room_stats_state
+ ON room_stats_state.room_id = er.room_id
+ WHERE er.id = ?
+ """
+
+ txn.execute(sql, [report_id])
+ row = txn.fetchone()
+
+ if not row:
+ return None
+
+ event_report = {
+ "id": row[0],
+ "received_ts": row[1],
+ "room_id": row[2],
+ "event_id": row[3],
+ "user_id": row[4],
+ "score": db_to_json(row[5]).get("score"),
+ "reason": db_to_json(row[5]).get("reason"),
+ "sender": row[6],
+ "canonical_alias": row[7],
+ "name": row[8],
+ "event_json": db_to_json(row[9]),
+ }
+
+ return event_report
+
+ return await self.db_pool.runInteraction(
+ "get_event_report", _get_event_report_txn, report_id
+ )
+
+ async def get_event_reports_paginate(
+ self,
+ start: int,
+ limit: int,
+ direction: Direction = Direction.BACKWARDS,
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Retrieve a paginated list of event reports
+
+ Args:
+ start: event offset to begin the query from
+ limit: number of rows to retrieve
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards)
+ user_id: search for user_id. Ignored if user_id is None
+ room_id: search for room_id. Ignored if room_id is None
+ Returns:
+ Tuple of:
+ json list of event reports
+ total number of event reports matching the filter criteria
+ """
+
+ def _get_event_reports_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ filters = []
+ args: List[object] = []
+
+ if user_id:
+ filters.append("er.user_id LIKE ?")
+ args.extend(["%" + user_id + "%"])
+ if room_id:
+ filters.append("er.room_id LIKE ?")
+ args.extend(["%" + room_id + "%"])
+
+ if direction == Direction.BACKWARDS:
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
+ sql = """
+ SELECT COUNT(*) as total_event_reports
+ FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
+ {}
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ count = cast(Tuple[int], txn.fetchone())[0]
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.content,
+ events.sender,
+ room_stats_state.canonical_alias,
+ room_stats_state.name
+ FROM event_reports AS er
+ LEFT JOIN events
+ ON events.event_id = er.event_id
+ JOIN room_stats_state
+ ON room_stats_state.room_id = er.room_id
+ {where_clause}
+ ORDER BY er.received_ts {order}
+ LIMIT ?
+ OFFSET ?
+ """.format(
+ where_clause=where_clause,
+ order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+
+ event_reports = []
+ for row in txn:
+ try:
+ s = db_to_json(row[5]).get("score")
+ r = db_to_json(row[5]).get("reason")
+ except Exception:
+ logger.error("Unable to parse json from event_reports: %s", row[0])
+ continue
+ event_reports.append(
+ {
+ "id": row[0],
+ "received_ts": row[1],
+ "room_id": row[2],
+ "event_id": row[3],
+ "user_id": row[4],
+ "score": s,
+ "reason": r,
+ "sender": row[6],
+ "canonical_alias": row[7],
+ "name": row[8],
+ }
+ )
+
+ return event_reports, count
+
+ return await self.db_pool.runInteraction(
+ "get_event_reports_paginate", _get_event_reports_paginate_txn
+ )
+
+ async def delete_event_report(self, report_id: int) -> bool:
+ """Remove an event report from database.
+
+ Args:
+ report_id: Report to delete
+
+ Returns:
+ Whether the report was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ table="event_reports",
+ keyvalues={"id": report_id},
+ desc="delete_event_report",
+ )
+ except StoreError:
+ # Deletion failed because report does not exist
+ return False
+
+ return True
+
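Since these report readers now live on `RoomWorkerStore` rather than the main-process-only `RoomStore`, they can be called from worker processes as well. A hedged usage sketch (the helper and its pruning policy are hypothetical; only the three store methods above are real):

```python
async def prune_unloadable_reports(store: "RoomWorkerStore", room_id: str) -> int:
    """Illustrative helper: page through a room's event reports (newest first by
    default) and delete any whose details can no longer be loaded, e.g. because
    the reported event's JSON has been purged."""
    deleted = 0
    reports, _total = await store.get_event_reports_paginate(
        start=0, limit=50, room_id=room_id
    )
    for report in reports:
        if await store.get_event_report(report["id"]) is None:
            if await store.delete_event_report(report["id"]):
                deleted += 1
    return deleted
```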
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
reason: Optional[str],
content: JsonDict,
received_ts: int,
- ) -> None:
+ ) -> int:
+ """Add an event report
+
+ Args:
+ room_id: Room that contains the reported event.
+ event_id: The reported event.
+ user_id: User who reports the event.
+ reason: Description that the user specifies.
+ content: Report request body (score and reason).
+ received_ts: Time when the user submitted the report (milliseconds).
+ Returns:
+            The ID of the event report.
+ """
next_id = self._event_reports_id_gen.get_next()
await self.db_pool.simple_insert(
table="event_reports",
@@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
},
desc="add_event_report",
)
-
- async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
- """Retrieve an event report
-
- Args:
- report_id: ID of reported event in database
- Returns:
- JSON dict of information from an event report or None if the
- report does not exist.
- """
-
- def _get_event_report_txn(
- txn: LoggingTransaction, report_id: int
- ) -> Optional[Dict[str, Any]]:
-
- sql = """
- SELECT
- er.id,
- er.received_ts,
- er.room_id,
- er.event_id,
- er.user_id,
- er.content,
- events.sender,
- room_stats_state.canonical_alias,
- room_stats_state.name,
- event_json.json AS event_json
- FROM event_reports AS er
- LEFT JOIN events
- ON events.event_id = er.event_id
- JOIN event_json
- ON event_json.event_id = er.event_id
- JOIN room_stats_state
- ON room_stats_state.room_id = er.room_id
- WHERE er.id = ?
- """
-
- txn.execute(sql, [report_id])
- row = txn.fetchone()
-
- if not row:
- return None
-
- event_report = {
- "id": row[0],
- "received_ts": row[1],
- "room_id": row[2],
- "event_id": row[3],
- "user_id": row[4],
- "score": db_to_json(row[5]).get("score"),
- "reason": db_to_json(row[5]).get("reason"),
- "sender": row[6],
- "canonical_alias": row[7],
- "name": row[8],
- "event_json": db_to_json(row[9]),
- }
-
- return event_report
-
- return await self.db_pool.runInteraction(
- "get_event_report", _get_event_report_txn, report_id
- )
-
- async def get_event_reports_paginate(
- self,
- start: int,
- limit: int,
- direction: Direction = Direction.BACKWARDS,
- user_id: Optional[str] = None,
- room_id: Optional[str] = None,
- ) -> Tuple[List[Dict[str, Any]], int]:
- """Retrieve a paginated list of event reports
-
- Args:
- start: event offset to begin the query from
- limit: number of rows to retrieve
- direction: Whether to fetch the most recent first (backwards) or the
- oldest first (forwards)
- user_id: search for user_id. Ignored if user_id is None
- room_id: search for room_id. Ignored if room_id is None
- Returns:
- Tuple of:
- json list of event reports
- total number of event reports matching the filter criteria
- """
-
- def _get_event_reports_paginate_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], int]:
- filters = []
- args: List[object] = []
-
- if user_id:
- filters.append("er.user_id LIKE ?")
- args.extend(["%" + user_id + "%"])
- if room_id:
- filters.append("er.room_id LIKE ?")
- args.extend(["%" + room_id + "%"])
-
- if direction == Direction.BACKWARDS:
- order = "DESC"
- else:
- order = "ASC"
-
- where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
-
- # We join on room_stats_state despite not using any columns from it
- # because the join can influence the number of rows returned;
- # e.g. a room that doesn't have state, maybe because it was deleted.
- # The query returning the total count should be consistent with
- # the query returning the results.
- sql = """
- SELECT COUNT(*) as total_event_reports
- FROM event_reports AS er
- JOIN room_stats_state ON room_stats_state.room_id = er.room_id
- {}
- """.format(
- where_clause
- )
- txn.execute(sql, args)
- count = cast(Tuple[int], txn.fetchone())[0]
-
- sql = """
- SELECT
- er.id,
- er.received_ts,
- er.room_id,
- er.event_id,
- er.user_id,
- er.content,
- events.sender,
- room_stats_state.canonical_alias,
- room_stats_state.name
- FROM event_reports AS er
- LEFT JOIN events
- ON events.event_id = er.event_id
- JOIN room_stats_state
- ON room_stats_state.room_id = er.room_id
- {where_clause}
- ORDER BY er.received_ts {order}
- LIMIT ?
- OFFSET ?
- """.format(
- where_clause=where_clause,
- order=order,
- )
-
- args += [limit, start]
- txn.execute(sql, args)
-
- event_reports = []
- for row in txn:
- try:
- s = db_to_json(row[5]).get("score")
- r = db_to_json(row[5]).get("reason")
- except Exception:
- logger.error("Unable to parse json from event_reports: %s", row[0])
- continue
- event_reports.append(
- {
- "id": row[0],
- "received_ts": row[1],
- "room_id": row[2],
- "event_id": row[3],
- "user_id": row[4],
- "score": s,
- "reason": r,
- "sender": row[6],
- "canonical_alias": row[7],
- "name": row[8],
- }
- )
-
- return event_reports, count
-
- return await self.db_pool.runInteraction(
- "get_event_reports_paginate", _get_event_reports_paginate_txn
- )
+ return next_id
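With `add_event_report` now returning the new row's ID instead of `None`, callers can hand that ID straight to the reader methods above. A small hypothetical example:

```python
import time


async def report_and_fetch(
    store: "RoomStore", room_id: str, event_id: str, user_id: str, reason: str
) -> dict:
    report_id = await store.add_event_report(
        room_id=room_id,
        event_id=event_id,
        user_id=user_id,
        reason=reason,
        content={"reason": reason, "score": -100},
        received_ts=int(time.time() * 1000),
    )
    # The same ID is what get_event_report/delete_event_report expect.
    report = await store.get_event_report(report_id)
    assert report is not None
    return report
```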
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked.
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
class SearchBackgroundUpdateStore(SearchWorkerStore):
-
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
-
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
insert_cols = []
qargs = []
- for (key, val) in chain(
+ for key, val in chain(
keyvalues.items(), absolutes.items(), additive_relatives.items()
):
insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
-
if direction == Direction.BACKWARDS:
order = "DESC"
else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f6a6fd4079..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,6 +14,7 @@
import logging
import re
+import unicodedata
from typing import (
TYPE_CHECKING,
Iterable,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:
-
# Get all the rooms that we want to process.
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
values={"display_name": display_name, "avatar_url": avatar_url},
)
+ # The display name that goes into the database index.
+ index_display_name = display_name
+ if index_display_name is not None:
+ index_display_name = _filter_text_for_index(index_display_name)
+
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id,
get_localpart_from_id(user_id),
get_domain_from_id(user_id),
- display_name,
+ index_display_name,
),
)
elif isinstance(self.database_engine, Sqlite3Engine):
- value = "%s %s" % (user_id, display_name) if display_name else user_id
+ value = (
+ "%s %s" % (user_id, index_display_name)
+ if index_display_name
+ else user_id
+ )
self.db_pool.simple_upsert_txn(
txn,
table="user_directory_search",
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results[0:limit]}
+def _filter_text_for_index(text: str) -> str:
+ """Transforms text before it is inserted into the user directory index, or searched
+ for in the user directory index.
+
+ Note that the user directory search table needs to be rebuilt whenever this function
+ changes.
+ """
+ # Lowercase the text, to make searches case-insensitive.
+ # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+ # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+ # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+ # all.
+ text = text.lower()
+
+ # Normalize the text. NFKC normalization has two effects:
+    # 1. It canonicalizes the text, i.e. maps all visually identical strings to the same
+ # string. For example, ["e", "◌́"] is mapped to ["é"].
+ # 2. It maps strings that are roughly equivalent to the same string.
+ # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+ # ["i", "9"].
+ text = unicodedata.normalize("NFKC", text)
+
+ # Note that nothing is done to make searches accent-insensitive.
+ # That could be achieved by converting to NFKD form instead (with combining accents
+ # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+ # The downside of this may be noisier search results, since search terms with
+ # explicit accents will match characters with no accents, or completely different
+ # accents.
+ #
+ # text = unicodedata.normalize("NFKD", text)
+ # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+ return text
+
+
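A quick standalone illustration of what the two steps buy us (re-stating just the lowercase + NFKC core so the snippet runs on its own):

```python
import unicodedata


def filter_text_for_index(text: str) -> str:
    # Same two steps as _filter_text_for_index above: lowercase, then NFKC-normalise.
    return unicodedata.normalize("NFKC", text.lower())


# A decomposed "É" (E + combining acute) and a precomposed "É" index identically...
assert filter_text_for_index("E\u0301ve") == filter_text_for_index("\u00c9ve") == "éve"
# ...and compatibility characters such as circled digits collapse to their ASCII forms.
assert filter_text_for_index("Agent ①") == "agent 1"
```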
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
We specifically add both a prefix and non prefix matching term so that
exact matches get ranked higher.
"""
+ search_term = _filter_text_for_index(search_term)
# Pull out the individual words, discarding any non-word characters.
results = _parse_words(search_term)
@@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
- results = _parse_words(search_term)
+ search_term = _filter_text_for_index(search_term)
+
+ escaped_words = []
+ for word in _parse_words(search_term):
+        # Postgres tsvector and tsquery quoting rules:
+        # in a word that may contain punctuation, existing quotes and backslashes
+        # must be doubled, and the whole word is then wrapped in single quotes.
+        # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
- both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
- exact = " & ".join("%s" % (result,) for result in results)
- prefix = " & ".join("%s:*" % (result,) for result in results)
+ quoted_word = word.replace("'", "''").replace("\\", "\\\\")
+ escaped_words.append(f"'{quoted_word}'")
+
+ both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words)
+ exact = " & ".join("%s" % (word,) for word in escaped_words)
+ prefix = " & ".join("%s:*" % (word,) for word in escaped_words)
return both, exact, prefix
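To make the quoting concrete, here is a standalone sketch of just the escaping performed above (not Synapse's helper itself):

```python
from typing import List


def escape_for_tsquery(words: List[str]) -> List[str]:
    escaped_words = []
    for word in words:
        # Double any quotes/backslashes inside the word, then wrap it in quotes.
        quoted_word = word.replace("'", "''").replace("\\", "\\\\")
        escaped_words.append(f"'{quoted_word}'")
    return escaped_words


words = escape_for_tsquery(["o'brien", "bob"])
print(" & ".join("(%s:* | %s)" % (w, w) for w in words))
# ('o''brien':* | 'o''brien') & ('bob':* | 'bob')
```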
@@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]:
if USE_ICU:
return _parse_words_with_icu(search_term)
+ return _parse_words_with_regex(search_term)
+
+
+def _parse_words_with_regex(search_term: str) -> List[str]:
+ """
+ Break down search term into words, when we don't have ICU available.
+ See: `_parse_words`
+ """
return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
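The regex fallback keeps hyphenated tokens intact while splitting on everything else, for example:

```python
import re
from typing import List


def parse_words_with_regex(search_term: str) -> List[str]:
    # Same pattern as above: runs of word characters and hyphens form a token.
    return re.findall(r"([\w\-]+)", search_term, re.UNICODE)


print(parse_words_with_regex("anna-lena @example:server"))
# ['anna-lena', 'example', 'server']
```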