diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index e84f8b42f7..379c78bb83 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -79,7 +79,7 @@ class Databases:
# If we're on a process that can persist events, also
# instantiate a `PersistEventsStore`
if hs.get_instance_name() in hs.config.worker.writers.events:
- persist_events = PersistEventsStore(hs, database, main)
+ persist_events = PersistEventsStore(hs, database, main, db_conn)
if "state" in database_config.databases:
logger.info(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 287606cb4f..cd1ceac50e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -42,7 +42,9 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -90,7 +92,11 @@ class PersistEventsStore:
"""
def __init__(
- self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+ self,
+ hs: "HomeServer",
+ db: DatabasePool,
+ main_data_store: "DataStore",
+ db_conn: Connection,
):
self.hs = hs
self.db_pool = db
@@ -474,6 +480,7 @@ class PersistEventsStore:
self._add_chain_cover_index(
txn,
self.db_pool,
+ self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
@@ -484,6 +491,7 @@ class PersistEventsStore:
cls,
txn,
db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
@@ -630,6 +638,7 @@ class PersistEventsStore:
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
+ event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
@@ -768,6 +777,7 @@ class PersistEventsStore:
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
@@ -880,7 +890,7 @@ class PersistEventsStore:
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
- newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
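
Note on the two hunks above: `_allocate_chain_ids` previously reached into `db_pool.event_chain_id_gen`; it and `_add_chain_cover_index` now take the generator as an explicit `SequenceGenerator` argument, so the same classmethod can be driven from both the persister and the background-update store changed below. A minimal sketch of the interface these call sites rely on follows; only `get_next_mult_txn` appears in the diff, and the abstract base plus the in-memory stand-in are assumptions for illustration.

    # Minimal sketch of the interface threaded through the hunks above.
    # Only get_next_mult_txn is used in the diff; the abstract base and
    # the in-memory stand-in below are assumptions for illustration.
    import abc
    import itertools
    from typing import Any, List


    class SequenceGenerator(metaclass=abc.ABCMeta):
        """Allocates monotonically increasing integer IDs."""

        @abc.abstractmethod
        def get_next_mult_txn(self, txn: Any, n: int) -> List[int]:
            """Allocate `n` fresh IDs within the given transaction."""
            ...


    class InMemorySequenceGenerator(SequenceGenerator):
        """Hypothetical test stand-in; not part of the diff."""

        def __init__(self, start: int = 1) -> None:
            self._counter = itertools.count(start)

        def get_next_mult_txn(self, txn: Any, n: int) -> List[int]:
            # Ignores the transaction: purely in-process allocation.
            return [next(self._counter) for _ in range(n)]


    gen = InMemorySequenceGenerator()
    assert gen.get_next_mult_txn(None, 3) == [1, 2, 3]
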
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 89274e75f7..c1626ccf28 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
+ self.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c8850a4707..edbe42f2bf 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ # We define this sequence here so that it can be referenced from both
+ # the DataStore and PersistEventsStore.
+ def get_chain_id_txn(txn):
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+ return txn.fetchone()[0]
+
+ self.event_chain_id_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_chain_id_txn,
+ "event_auth_chain_id",
+ table="event_auth_chains",
+ id_column="chain_id",
+ )
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
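
This is the other half of the refactor: the generator is built once on the worker store so the `DataStore` and `PersistEventsStore` call sites share a single allocator, with `get_chain_id_txn` supplying the current `max(chain_id)` as a seed where the engine lacks a native sequence. A hedged sketch of what a `build_sequence_generator`-style factory can look like; the real Synapse helper differs in detail, and the fallback class, locking, and engine handling here are assumptions.

    # Hedged sketch of a build_sequence_generator-style factory; the real
    # Synapse helper differs in detail, and the fallback class, locking
    # and engine handling here are assumptions for illustration.
    import threading
    from typing import Any, Callable, List, Optional


    class LocalSequenceGenerator:
        """Fallback for engines without native sequences (e.g. SQLite).

        Lazily seeded from `get_first_callback`, which returns the current
        maximum ID (the hunk above uses max(chain_id) of event_auth_chains).
        """

        def __init__(self, get_first_callback: Callable[[Any], int]) -> None:
            self._callback: Optional[Callable[[Any], int]] = get_first_callback
            self._current = 0
            self._lock = threading.Lock()

        def get_next_mult_txn(self, txn: Any, n: int) -> List[int]:
            with self._lock:
                if self._callback is not None:
                    self._current = self._callback(txn)
                    self._callback = None
                first = self._current + 1
                self._current += n
                return list(range(first, self._current + 1))


    def build_sequence_generator(
        db_conn: Any,
        database_engine: Any,
        get_first_callback: Callable[[Any], int],
        sequence_name: str,
        table: Optional[str],
        id_column: Optional[str],
    ) -> LocalSequenceGenerator:
        # On Postgres this would instead wrap the named native sequence and
        # verify it against max(id_column) in `table` (see the state/store.py
        # hunk below); this sketch only implements the emulated path.
        return LocalSequenceGenerator(get_first_callback)


    gen = build_sequence_generator(
        None, None, lambda txn: 7, "event_auth_chain_id",
        table="event_auth_chains", id_column="chain_id",
    )
    assert gen.get_next_mult_txn(None, 3) == [8, 9, 10]
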
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index a0313c3ccf..274f8de595 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
@@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
)
+class MediaSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning media with
+ get_local_media_by_user_paginate
+ """
+
+ MEDIA_ID = "media_id"
+ UPLOAD_NAME = "upload_name"
+ CREATED_TS = "created_ts"
+ LAST_ACCESS_TS = "last_access_ts"
+ MEDIA_LENGTH = "media_length"
+ MEDIA_TYPE = "media_type"
+ QUARANTINED_BY = "quarantined_by"
+ SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_by_user_paginate(
- self, start: int, limit: int, user_id: str
+ self,
+ start: int,
+ limit: int,
+ user_id: str,
+ order_by: str = MediaSortOrder.CREATED_TS.value,
+ direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which a user_id has uploaded
@@ -127,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: offset in the list
limit: maximum amount of media_ids to retrieve
user_id: fully-qualified user id
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
Returns:
A paginated list of all metadata of user's media,
plus the total count of all the user's media
@@ -134,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(txn):
+ # Set ordering
+ order_by_column = MediaSortOrder(order_by).value
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
args = [user_id]
sql = """
SELECT COUNT(*) as total_media
@@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"safe_from_quarantine"
FROM local_media_repository
WHERE user_id = ?
- ORDER BY created_ts DESC, media_id DESC
+ ORDER BY {order_by_column} {order}, media_id ASC
LIMIT ? OFFSET ?
- """
+ """.format(
+ order_by_column=order_by_column,
+ order=order,
+ )
args += [limit, start]
txn.execute(sql, args)
@@ -344,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- await self.db_pool.simple_insert(
- "local_media_repository_thumbnails",
- {
+ await self.db_pool.simple_upsert(
+ table="local_media_repository_thumbnails",
+ keyvalues={
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
- "thumbnail_length": thumbnail_length,
},
+ values={"thumbnail_length": thumbnail_length},
desc="store_local_thumbnail",
)
@@ -498,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- await self.db_pool.simple_insert(
- "remote_media_cache_thumbnails",
- {
+ await self.db_pool.simple_upsert(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
"media_origin": origin,
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
- "thumbnail_length": thumbnail_length,
- "filesystem_id": filesystem_id,
},
+ values={"thumbnail_length": thumbnail_length},
+ insertion_values={"filesystem_id": filesystem_id},
desc="store_remote_media_thumbnail",
)
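
Because `order_by_column` is interpolated into the SQL string rather than bound as a parameter, the round trip through `MediaSortOrder(order_by)` doubles as input validation: any value that is not a known column raises `ValueError` before it can reach the query. A self-contained illustration, with the enum trimmed to three members and the surrounding query elided.

    # Self-contained illustration; enum members trimmed to three and the
    # surrounding query elided.
    from enum import Enum


    class MediaSortOrder(Enum):
        MEDIA_ID = "media_id"
        UPLOAD_NAME = "upload_name"
        CREATED_TS = "created_ts"


    def order_clause(order_by: str, direction: str) -> str:
        # Round-tripping through the enum rejects anything that is not a
        # known column before it is interpolated into the SQL string.
        column = MediaSortOrder(order_by).value
        order = "DESC" if direction == "b" else "ASC"
        return "ORDER BY {} {}, media_id ASC".format(column, order)


    print(order_clause("created_ts", "b"))
    # ORDER BY created_ts DESC, media_id ASC

    try:
        order_clause("created_ts; DROP TABLE local_media_repository", "f")
    except ValueError:
        print("rejected unknown sort column")

Separately, switching the two thumbnail writers from `simple_insert` to `simple_upsert` makes regenerating an existing thumbnail idempotent instead of a unique-constraint failure.
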
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
+ db_conn,
database.engine,
find_max_generated_user_id_localpart,
"user_id_seq",
+ table=None,
+ id_column=None,
)
self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
new file mode 100644
index 0000000000..2442eea6bc
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Delete all pushers associated with deleted devices. This is to clear up after
+-- a bug where they weren't correctly deleted when using workers.
+DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
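
A quick demonstration of the deletion's `NOT IN (subquery)` semantics, using sqlite3 and a schema simplified to the two columns the migration touches.

    # Hedged demonstration (sqlite3, simplified two-column schema) of which
    # pusher rows the NOT IN (subquery) deletion keeps. Note that rows with
    # a NULL access_token would also survive, since NULL NOT IN (...) is
    # never true; the migration assumes the column is populated.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE access_tokens (id INTEGER PRIMARY KEY);
        CREATE TABLE pushers (pushkey TEXT, access_token INTEGER);
        INSERT INTO access_tokens (id) VALUES (1);
        INSERT INTO pushers VALUES ('live', 1), ('stale', 2);
        """
    )
    conn.execute(
        "DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens)"
    )
    print(conn.execute("SELECT pushkey FROM pushers").fetchall())
    # [('live',)] -- the pusher whose token was deleted is gone
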
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 63f88eac51..1026f321e5 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def add_users_in_public_rooms(
self, room_id: str, user_ids: Iterable[str]
) -> None:
- """Insert entries into the users_who_share_private_rooms table. The first
- user should be a local user.
+ """Insert entries into the users_in_public_rooms table.
Args:
room_id
@@ -556,6 +555,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
+ self._prefer_local_users_in_search = (
+ hs.config.user_directory_search_prefer_local_users
+ )
+ self._server_name = hs.config.server_name
+
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
@@ -665,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- @cached()
async def get_shared_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
@@ -754,9 +757,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
+ # We allow manipulating the ranking algorithm by injecting statements
+ # based on config options.
+ additional_ordering_statements = []
+ ordering_arguments = ()
+
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+ # If enabled, this config option will rank local users higher than those on
+ # remote instances.
+ if self._prefer_local_users_in_search:
+ # This statement checks whether a given user's user ID ends with the
+ # local server name, boosting the rank of local users
+ statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
+ additional_ordering_statements.append(statement)
+
+ ordering_arguments += ("%:" + self._server_name,)
+
# We order by rank and then if they have profile info
# The ranking algorithm is hand tweaked for "best" results. Broadly
# the idea is we give a higher weight to exact matches.
@@ -767,7 +785,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
WHERE
- %s
+ %(where_clause)s
AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -787,33 +805,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
8
)
)
+ %(order_case_statements)s
DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (
- where_clause,
+ """ % {
+ "where_clause": where_clause,
+ "order_case_statements": " ".join(additional_ordering_statements),
+ }
+ args = (
+ join_args
+ + (full_query, exact_query, prefix_query)
+ + ordering_arguments
+ + (limit + 1,)
)
- args = join_args + (full_query, exact_query, prefix_query, limit + 1)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
+ # If enabled, this config option will rank local users higher than those on
+ # remote instances.
+ if self._prefer_local_users_in_search:
+ # This statement checks whether a given user's user ID ends with the
+ # local server name, sorting local users first
+ #
+ # Note that we need to include a comma at the end for valid SQL
+ statement = "user_id LIKE ? DESC,"
+ additional_ordering_statements.append(statement)
+
+ ordering_arguments += ("%:" + self._server_name,)
+
sql = """
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
WHERE
- %s
+ %(where_clause)s
AND value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
+ %(order_statements)s
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (
- where_clause,
- )
- args = join_args + (search_query, limit + 1)
+ """ % {
+ "where_clause": where_clause,
+ "order_statements": " ".join(additional_ordering_statements),
+ }
+ args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
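
The `"%:" + server_name` bind parameter assembled above is a plain SQL LIKE suffix match on the user ID's domain: on Postgres it feeds the `CASE ... THEN 2.0` rank multiplier, and on SQLite the extra `user_id LIKE ? DESC,` ordering term. A quick sqlite3 check that the pattern picks out exactly the local users; the server name is an assumed example.

    # Quick sqlite3 check that the "%:server" LIKE pattern bound above picks
    # out exactly the local users; the server name is an assumed example.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE user_directory (user_id TEXT)")
    conn.executemany(
        "INSERT INTO user_directory VALUES (?)",
        [("@alice:example.com",), ("@bob:elsewhere.org",)],
    )
    rows = conn.execute(
        "SELECT user_id, user_id LIKE ? FROM user_directory",
        ("%:example.com",),
    ).fetchall()
    print(rows)
    # [('@alice:example.com', 1), ('@bob:elsewhere.org', 0)]
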
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index b16b9905d8..e2240703a7 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return txn.fetchone()[0]
self._state_group_seq_gen = build_sequence_generator(
- self.database_engine, get_max_state_group_txn, "state_group_id_seq"
- )
- self._state_group_seq_gen.check_consistency(
- db_conn, table="state_groups", id_column="id"
+ db_conn,
+ self.database_engine,
+ get_max_state_group_txn,
+ "state_group_id_seq",
+ table="state_groups",
+ id_column="id",
)
@cached(max_entries=10000, iterable=True)
|