Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/database.py                                            | 16
-rw-r--r--  synapse/storage/databases/__init__.py                                  |  2
-rw-r--r--  synapse/storage/databases/main/events.py                               | 14
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py                    |  1
-rw-r--r--  synapse/storage/databases/main/events_worker.py                        | 16
-rw-r--r--  synapse/storage/databases/main/media_repository.py                     | 59
-rw-r--r--  synapse/storage/databases/main/registration.py                         | 19
-rw-r--r--  synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql | 19
-rw-r--r--  synapse/storage/databases/main/user_directory.py                       | 63
-rw-r--r--  synapse/storage/databases/state/store.py                               | 10
-rw-r--r--  synapse/storage/util/sequence.py                                       | 24
11 files changed, 195 insertions, 48 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4646926449..f1ba529a2d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection

 # python 3 does not have a maximum int value
@@ -381,7 +380,10 @@ class DatabasePool:
     _TXN_ID = 0

     def __init__(
-        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+        self,
+        hs,
+        database_config: DatabaseConnectionConfig,
+        engine: BaseDatabaseEngine,
     ):
         self.hs = hs
         self._clock = hs.get_clock()
@@ -420,16 +422,6 @@ class DatabasePool:
             self._check_safe_to_upsert,
         )

-        # We define this sequence here so that it can be referenced from both
-        # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
-            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
-
-        self.event_chain_id_gen = build_sequence_generator(
-            engine, get_chain_id_txn, "event_auth_chain_id"
-        )
-
     def is_running(self) -> bool:
         """Is the database pool currently running"""
         return self._db_pool.running
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index e84f8b42f7..379c78bb83 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -79,7 +79,7 @@ class Databases:
                 # If we're on a process that can persist events also
                 # instantiate a `PersistEventsStore`
                 if hs.get_instance_name() in hs.config.worker.writers.events:
-                    persist_events = PersistEventsStore(hs, database, main)
+                    persist_events = PersistEventsStore(hs, database, main, db_conn)

                 if "state" in database_config.databases:
                     logger.info(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 287606cb4f..cd1ceac50e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -42,7 +42,9 @@ from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -90,7 +92,11 @@ class PersistEventsStore:
     """

     def __init__(
-        self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+        self,
+        hs: "HomeServer",
+        db: DatabasePool,
+        main_data_store: "DataStore",
+        db_conn: Connection,
     ):
         self.hs = hs
         self.db_pool = db
@@ -474,6 +480,7 @@ class PersistEventsStore:
         self._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.store.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -484,6 +491,7 @@ class PersistEventsStore:
         cls,
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -630,6 +638,7 @@ class PersistEventsStore:
         new_chain_tuples = cls._allocate_chain_ids(
             txn,
             db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -768,6 +777,7 @@ class PersistEventsStore:
     def _allocate_chain_ids(
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -880,7 +890,7 @@ class PersistEventsStore:
             chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]

         # Generate new chain IDs for all unallocated chain IDs.
-        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+        newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
             txn, len(unallocated_chain_ids)
         )
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 89274e75f7..c1626ccf28 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             PersistEventsStore._add_chain_cover_index(
                 txn,
                 self.db_pool,
+                self.event_chain_id_gen,
                 event_to_room_id,
                 event_to_types,
                 event_to_auth_chain,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c8850a4707..edbe42f2bf 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0

+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            db_conn,
+            database.engine,
+            get_chain_id_txn,
+            "event_auth_chain_id",
+            table="event_auth_chains",
+            id_column="chain_id",
+        )
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
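With the hunks above, the chain-ID sequence moves off `DatabasePool` and onto `EventsWorkerStore`, and is handed to `PersistEventsStore`'s static helpers as an explicit `SequenceGenerator` argument rather than reached through `db_pool`. For intuition about the two backends `build_sequence_generator` can return: on PostgreSQL it wraps the named sequence (`event_auth_chain_id` here), while on SQLite it falls back to an in-memory counter seeded by the callback. A toy sketch of that fallback behaviour (illustrative only, not the real `LocalSequenceGenerator`):

```python
import threading
from typing import Callable, List, Optional

class LocalSequenceSketch:
    """Toy model of the SQLite fallback: seed once from a callback such as
    get_chain_id_txn above, then hand out consecutive IDs from memory."""

    def __init__(self, get_first: Callable[..., int]):
        self._get_first = get_first
        self._current = None  # type: Optional[int]
        self._lock = threading.Lock()

    def get_next_mult_txn(self, txn, n: int) -> List[int]:
        with self._lock:
            if self._current is None:
                # e.g. SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains
                self._current = self._get_first(txn)
            first = self._current + 1
            self._current += n
            return list(range(first, first + n))
```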
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index a0313c3ccf..274f8de595 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from enum import Enum
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 from synapse.storage._base import SQLBaseStore
@@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
 )


+class MediaSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning media with
+    get_local_media_by_user_paginate
+    """
+
+    MEDIA_ID = "media_id"
+    UPLOAD_NAME = "upload_name"
+    CREATED_TS = "created_ts"
+    LAST_ACCESS_TS = "last_access_ts"
+    MEDIA_LENGTH = "media_length"
+    MEDIA_TYPE = "media_type"
+    QUARANTINED_BY = "quarantined_by"
+    SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )

     async def get_local_media_by_user_paginate(
-        self, start: int, limit: int, user_id: str
+        self,
+        start: int,
+        limit: int,
+        user_id: str,
+        order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
+        direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
         which an user_id has uploaded
@@ -127,6 +149,8 @@
             start: offset in the list
             limit: maximum amount of media_ids to retrieve
             user_id: fully-qualified user id
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending

         Returns:
             A paginated list of all metadata of user's media,
             plus the total count of all the user's media
         """

         def get_local_media_by_user_paginate_txn(txn):

+            # Set ordering
+            order_by_column = MediaSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             args = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
@@ -155,9 +187,12 @@
                     "safe_from_quarantine"
                 FROM local_media_repository
                 WHERE user_id = ?
-                ORDER BY created_ts DESC, media_id DESC
+                ORDER BY {order_by_column} {order}, media_id ASC
                 LIMIT ? OFFSET ?
-            """
+            """.format(
+                order_by_column=order_by_column,
+                order=order,
+            )

             args += [limit, start]
             txn.execute(sql, args)
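A hedged usage sketch of the widened pagination API (the `store` handle and user ID are invented for illustration; note that `order_by` travels as the enum's string value, matching the `MediaSortOrder(order_by).value` round-trip in the transaction above):

```python
from synapse.storage.databases.main.media_repository import MediaSortOrder

async def largest_media_first(store, user_id: str = "@alice:example.com"):
    # direction="b" sorts backwards, i.e. descending, so the biggest
    # uploads come back on the first page.
    media, total = await store.get_local_media_by_user_paginate(
        start=0,
        limit=10,
        user_id=user_id,
        order_by=MediaSortOrder.MEDIA_LENGTH.value,
        direction="b",
    )
    return media, total
```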
@@ -344,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        await self.db_pool.simple_insert(
-            "local_media_repository_thumbnails",
-            {
+        await self.db_pool.simple_upsert(
+            table="local_media_repository_thumbnails",
+            keyvalues={
                 "media_id": media_id,
                 "thumbnail_width": thumbnail_width,
                 "thumbnail_height": thumbnail_height,
                 "thumbnail_method": thumbnail_method,
                 "thumbnail_type": thumbnail_type,
-                "thumbnail_length": thumbnail_length,
             },
+            values={"thumbnail_length": thumbnail_length},
             desc="store_local_thumbnail",
         )

@@ -498,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        await self.db_pool.simple_insert(
-            "remote_media_cache_thumbnails",
-            {
+        await self.db_pool.simple_upsert(
+            table="remote_media_cache_thumbnails",
+            keyvalues={
                 "media_origin": origin,
                 "media_id": media_id,
                 "thumbnail_width": thumbnail_width,
                 "thumbnail_height": thumbnail_height,
                 "thumbnail_method": thumbnail_method,
                 "thumbnail_type": thumbnail_type,
-                "thumbnail_length": thumbnail_length,
-                "filesystem_id": filesystem_id,
             },
+            values={"thumbnail_length": thumbnail_length},
+            insertion_values={"filesystem_id": filesystem_id},
             desc="store_remote_media_thumbnail",
         )
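Replacing `simple_insert` with `simple_upsert` makes thumbnail storage idempotent: regenerating an existing thumbnail updates its `thumbnail_length` (and, for remote media, leaves `filesystem_id` untouched via `insertion_values`) instead of raising a unique-constraint error. Where native upserts are unavailable, `DatabasePool` emulates them with an update-then-insert dance; roughly (a toy sketch of the semantics, not Synapse's actual implementation):

```python
def emulated_upsert_sketch(txn, table, keyvalues, values, insertion_values):
    """Try an UPDATE keyed on `keyvalues`; INSERT only if no row matched.
    Assumes non-empty `values`; illustrative of the semantics only."""
    txn.execute(
        "UPDATE %s SET %s WHERE %s"
        % (
            table,
            ", ".join("%s = ?" % k for k in values),
            " AND ".join("%s = ?" % k for k in keyvalues),
        ),
        list(values.values()) + list(keyvalues.values()),
    )
    if txn.rowcount > 0:
        return False  # updated an existing row
    row = {**keyvalues, **values, **insertion_values}
    txn.execute(
        "INSERT INTO %s (%s) VALUES (%s)"
        % (table, ", ".join(row), ", ".join(["?"] * len(row))),
        list(row.values()),
    )
    return True  # inserted a new row
```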
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:


 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )

         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):


 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self._clock = hs.get_clock()
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
new file mode 100644
index 0000000000..2442eea6bc
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Delete all pushers associated with deleted devices. This is to clear up after
+-- a bug where they weren't correctly deleted when using workers.
+DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
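The new delta is a one-shot cleanup, and its effect is easy to reproduce in miniature with the standard library (schema trimmed to the two columns involved; all values invented):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE access_tokens (id INTEGER PRIMARY KEY);
    CREATE TABLE pushers (id INTEGER PRIMARY KEY, access_token INTEGER);
    INSERT INTO access_tokens (id) VALUES (1);
    -- pusher 11 points at a token that no longer exists
    INSERT INTO pushers (id, access_token) VALUES (10, 1), (11, 2);
    """
)
# The delta's statement: drop pushers whose token is gone.
conn.execute(
    "DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens)"
)
assert [r[0] for r in conn.execute("SELECT id FROM pushers")] == [10]
```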
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 63f88eac51..1026f321e5 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     async def add_users_in_public_rooms(
         self, room_id: str, user_ids: Iterable[str]
     ) -> None:
-        """Insert entries into the users_who_share_private_rooms table. The first
-        user should be a local user.
+        """Insert entries into the users_in_public_rooms table.

         Args:
             room_id
@@ -556,6 +555,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)

+        self._prefer_local_users_in_search = (
+            hs.config.user_directory_search_prefer_local_users
+        )
+        self._server_name = hs.config.server_name
+
     async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
             self.db_pool.simple_delete_txn(
@@ -665,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             users.update(rows)
         return list(users)

-    @cached()
     async def get_shared_rooms_for_users(
         self, user_id: str, other_user_id: str
     ) -> Set[str]:
@@ -754,9 +757,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             )
         """

+        # We allow manipulating the ranking algorithm by injecting statements
+        # based on config options.
+        additional_ordering_statements = []
+        ordering_arguments = ()
+
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)

+            # If enabled, this config option will rank local users higher than those on
+            # remote instances.
+            if self._prefer_local_users_in_search:
+                # This statement checks whether a given user's user ID contains a server name
+                # that matches the local server
+                statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
+                additional_ordering_statements.append(statement)
+
+                ordering_arguments += ("%:" + self._server_name,)
+
             # We order by rank and then if they have profile info
             # The ranking algorithm is hand tweaked for "best" results. Broadly
             # the idea is we give a higher weight to exact matches.
@@ -767,7 +785,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                 FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
                 WHERE
-                    %s
+                    %(where_clause)s
                     AND vector @@ to_tsquery('simple', ?)
                 ORDER BY
                     (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -787,33 +805,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                             8
                         )
                     )
+                    %(order_case_statements)s
                     DESC,
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (
-                where_clause,
+            """ % {
+                "where_clause": where_clause,
+                "order_case_statements": " ".join(additional_ordering_statements),
+            }
+            args = (
+                join_args
+                + (full_query, exact_query, prefix_query)
+                + ordering_arguments
+                + (limit + 1,)
             )
-            args = join_args + (full_query, exact_query, prefix_query, limit + 1)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)

+            # If enabled, this config option will rank local users higher than those on
+            # remote instances.
+            if self._prefer_local_users_in_search:
+                # This statement checks whether a given user's user ID contains a server name
+                # that matches the local server
+                #
+                # Note that we need to include a comma at the end for valid SQL
+                statement = "user_id LIKE ? DESC,"
+                additional_ordering_statements.append(statement)
+
+                ordering_arguments += ("%:" + self._server_name,)
+
             sql = """
                 SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
                 WHERE
-                    %s
+                    %(where_clause)s
                     AND value MATCH ?
                 ORDER BY
                     rank(matchinfo(user_directory_search)) DESC,
+                    %(order_statements)s
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (
-                where_clause,
-            )
-            args = join_args + (search_query, limit + 1)
+            """ % {
+                "where_clause": where_clause,
+                "order_statements": " ".join(additional_ordering_statements),
+            }
+            args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
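Both engine branches inject the same signal: the pattern `"%:" + server_name` is handed to SQL `LIKE`, which matches exactly the user IDs homed on the local server (Postgres then multiplies the rank by 2.0; SQLite adds a leading `ORDER BY` term). The equivalence in plain Python, for intuition (a hedged restatement, not code from the diff):

```python
def is_preferred_local_user(user_id: str, server_name: str) -> bool:
    # SQL: user_id LIKE '%:example.com' -- true when the trailing
    # ":server" part of the Matrix ID names the local server.
    return user_id.endswith(":" + server_name)

assert is_preferred_local_user("@alice:example.com", "example.com")
assert not is_preferred_local_user("@bob:remote.org", "example.com")
```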
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index b16b9905d8..e2240703a7 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             return txn.fetchone()[0]

         self._state_group_seq_gen = build_sequence_generator(
-            self.database_engine, get_max_state_group_txn, "state_group_id_seq"
-        )
-        self._state_group_seq_gen.check_consistency(
-            db_conn, table="state_groups", id_column="id"
+            db_conn,
+            self.database_engine,
+            get_max_state_group_txn,
+            "state_group_id_seq",
+            table="state_groups",
+            id_column="id",
         )

     @cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 3ea637b281..36a67e7019 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):


 def build_sequence_generator(
+    db_conn: "LoggingDatabaseConnection",
     database_engine: BaseDatabaseEngine,
     get_first_callback: GetFirstCallbackType,
     sequence_name: str,
+    table: Optional[str],
+    id_column: Optional[str],
+    stream_name: Optional[str] = None,
+    positive: bool = True,
 ) -> SequenceGenerator:
     """Get the best impl of SequenceGenerator available

@@ -265,8 +270,23 @@ def build_sequence_generator(
         get_first_callback: a callback which gets the next sequence ID. Used if
             we're on sqlite.
         sequence_name: the name of a postgres sequence to use.
+        table, id_column, stream_name, positive: If set then `check_consistency`
+            is called on the created sequence. See docstring for
+            `check_consistency` details.
     """
     if isinstance(database_engine, PostgresEngine):
-        return PostgresSequenceGenerator(sequence_name)
+        seq = PostgresSequenceGenerator(sequence_name)  # type: SequenceGenerator
     else:
-        return LocalSequenceGenerator(get_first_callback)
+        seq = LocalSequenceGenerator(get_first_callback)
+
+    if table:
+        assert id_column
+        seq.check_consistency(
+            db_conn=db_conn,
+            table=table,
+            id_column=id_column,
+            stream_name=stream_name,
+            positive=positive,
+        )
+
+    return seq
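With `db_conn` threaded through `build_sequence_generator`, any caller that supplies `table`/`id_column` now gets the startup consistency check automatically, as the state-store hunk above shows. The essence of that check against a PostgreSQL cursor (a sketch only; the real `check_consistency` also handles `stream_name` bookkeeping and `positive=False` sequences):

```python
class SequenceBehindError(RuntimeError):
    """A sequence has fallen behind the rows that were allocated from it."""

def check_consistency_sketch(cur, sequence_name: str, table: str, id_column: str):
    # Where the sequence believes it is...
    cur.execute("SELECT last_value, is_called FROM %s" % (sequence_name,))
    last_value, is_called = cur.fetchone()
    # ...versus the largest ID actually stored in the table.
    cur.execute("SELECT COALESCE(max(%s), 0) FROM %s" % (id_column, table))
    (max_in_table,) = cur.fetchone()
    # If rows exist beyond the sequence position, nextval() would hand out
    # an ID that is already taken, so fail loudly at startup instead.
    if max_in_table > 0 and (not is_called or last_value < max_in_table):
        raise SequenceBehindError(
            "Sequence %s (at %s) is behind %s.%s (max %s)"
            % (sequence_name, last_value, table, id_column, max_in_table)
        )
```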