diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3d98d3f5f8..0623da9aa1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import random
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
@@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta):
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
- self.rand = random.SystemRandom()
def process_replication_rows(
self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 49c7606d51..9cce62ae6c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -67,7 +67,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
-from .transactions import TransactionStore
+from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@@ -83,7 +83,7 @@ class DataStore(
StreamStore,
ProfileStore,
PresenceStore,
- TransactionStore,
+ TransactionWorkerStore,
DirectoryStore,
KeyStore,
StateStore,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index d60010e942..074b077bef 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = LruCache(
- cache_name="client_ip_last_seen", keylen=4, max_size=50000
+ cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..fd87ba71ab 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
- async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = LruCache(
- cache_name="device_id_exists", keylen=2, max_size=10000
+ cache_name="device_id_exists", max_size=10000
)
async def store_device(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
num_args=1,
)
async def _get_bare_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str]
+ self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
- user_ids: List[str],
+ user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 2c823e09cf..6963bbf7f4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache = LruCache(
cache_name="*getEvent*",
- keylen=3,
max_size=hs.config.caches.event_cache_size,
)
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 0e86807834..6990f3ed1d 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
- def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
+ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index db22fab23e..6a2baa7841 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
db_conn, "presence_stream", "stream_id"
)
+ self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
@@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore):
)
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+ # Delete old rows to stop database from getting really big
+ sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+ for states in batch_iter(presence_states, 50):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", [s.user_id for s in states]
+ )
+ txn.execute(sql + clause, [stream_id] + list(args))
+
# Actually insert new rows
self.db_pool.simple_insert_many_txn(
txn,
@@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore):
],
)
- # Delete old rows to stop database from getting really big
- sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
-
- for states in batch_iter(presence_states, 50):
- clause, args = make_in_list_sql_clause(
- self.database_engine, "user_id", [s.user_id for s in states]
- )
- txn.execute(sql + clause, [stream_id] + list(args))
-
async def get_all_presence_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
@@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore):
return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ async def should_user_receive_full_presence_with_token(
+ self,
+ user_id: str,
+ from_token: int,
+ ) -> bool:
+ """Check whether the given user should receive full presence using the stream token
+ they're updating from.
+
+ Args:
+ user_id: The ID of the user to check.
+ from_token: The stream token included in their /sync token.
+
+ Returns:
+ True if the user should have full presence sent to them, False otherwise.
+ """
+
+ def _should_user_receive_full_presence_with_token_txn(txn):
+ sql = """
+ SELECT 1 FROM users_to_send_full_presence_to
+ WHERE user_id = ?
+ AND presence_stream_id >= ?
+ """
+ txn.execute(sql, (user_id, from_token))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "should_user_receive_full_presence_with_token",
+ _should_user_receive_full_presence_with_token_txn,
+ )
+
+ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+ """Adds to the list of users who should receive a full snapshot of presence
+ upon their next sync.
+
+ Args:
+ user_ids: An iterable of user IDs.
+ """
+ # Add user entries to the table, updating the presence_stream_id column if the user already
+ # exists in the table.
+ await self.db_pool.simple_upsert_many(
+ table="users_to_send_full_presence_to",
+ key_names=("user_id",),
+ key_values=[(user_id,) for user_id in user_ids],
+ value_names=("presence_stream_id",),
+ # We save the current presence stream ID token along with the user ID entry so
+ # that when a user /sync's, even if they sync multiple times across separate
+ # devices at different times, each device will receive full presence once - when
+ # the presence stream ID in their sync token is less than the one in the table
+ # for their user ID.
+ value_values=(
+ (self._presence_id_gen.get_current_token(),) for _ in user_ids
+ ),
+ desc="add_users_to_send_full_presence_to",
+ )
+
async def get_presence_for_all_users(
self,
include_offline: bool = True,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d36b18a0e9..77e2eb27db 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -1077,7 +1078,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
- expiration_ts = self.rand.randrange(
+ expiration_ts = random.randrange(
expiration_ts - self._account_validity_startup_job_max_delta,
expiration_ts,
)
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 82335e7a9d..d211c423b2 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -16,13 +16,15 @@ import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
-from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.descriptors import cached
db_binary_type = memoryview
@@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
"_TransactionRow", ("response_code", "response_json")
)
-SENTINEL = object()
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DestinationRetryTimings:
+ """The current destination retry timing info for a remote server."""
-class TransactionWorkerStore(SQLBaseStore):
+ # The first time we tried and failed to reach the remote server, in ms.
+ failure_ts: int
+
+ # The last time we tried and failed to reach the remote server, in ms.
+ retry_last_ts: int
+
+ # How long since the last time we tried to reach the remote server before
+ # trying again, in ms.
+ retry_interval: int
+
+
+class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
"_cleanup_transactions", _cleanup_transactions_txn
)
-
-class TransactionStore(TransactionWorkerStore):
- """A collection of queries for handling PDUs."""
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self._destination_retry_cache = ExpiringCache(
- cache_name="get_destination_retry_timings",
- clock=self._clock,
- expiry_ms=5 * 60 * 1000,
- )
-
async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
@@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
desc="set_received_txn_response",
)
- async def get_destination_retry_timings(self, destination):
+ @cached(max_entries=10000)
+ async def get_destination_retry_timings(
+ self,
+ destination: str,
+ ) -> Optional[DestinationRetryTimings]:
"""Gets the current retry timings (if any) for a given destination.
Args:
@@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
Otherwise a dict for the retry scheme
"""
- result = self._destination_retry_cache.get(destination, SENTINEL)
- if result is not SENTINEL:
- return result
-
result = await self.db_pool.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
)
- # We don't hugely care about race conditions between getting and
- # invalidating the cache, since we time out fairly quickly anyway.
- self._destination_retry_cache[destination] = result
return result
- def _get_destination_retry_timings(self, txn, destination):
+ def _get_destination_retry_timings(
+ self, txn, destination: str
+ ) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
+ retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
- return result
+ return DestinationRetryTimings(**result)
else:
return None
@@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
retry_interval: how long until next retry in ms
"""
- self._destination_retry_cache.pop(destination, None)
if self.database_engine.can_native_upsert:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
@@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
@@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
},
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
async def store_destination_rooms_entries(
self,
destinations: Iterable[str],
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index acf6b2fb64..1ecdd40c38 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable
+
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
- async def are_users_erased(self, user_ids):
+ async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
"""
Checks which users in a list have requested erasure
Args:
- user_ids (iterable[str]): full user id to check
+ user_ids: full user ids to check
Returns:
- dict[str, bool]:
- for each user, whether the user has requested erasure.
+ for each user, whether the user has requested erasure.
"""
- # this serves the dual purpose of (a) making sure we can do len and
- # iterate it multiple times, and (b) avoiding duplicates.
- user_ids = tuple(set(user_ids))
-
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
diff --git a/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql
new file mode 100644
index 0000000000..07b0f53ecf
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a table that keeps track of a list of users who should, upon their next
+-- sync request, receive presence for all currently online users that they are
+-- "interested" in.
+
+-- The motivation for a DB table over an in-memory list is so that this list
+-- can be added to and retrieved from by any worker. Specifically, we don't
+-- want to duplicate work across multiple sync workers.
+
+CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to(
+ -- The user ID to send full presence to.
+ user_id TEXT PRIMARY KEY,
+ -- A presence stream ID token - the current presence stream token when the row was last upserted.
+ -- If a user calls /sync and this token is part of the update they're to receive, we also include
+ -- full user presence in the response.
+ -- This allows multiple devices for a user to receive full presence whenever they next call /sync.
+ presence_stream_id BIGINT,
+ FOREIGN KEY (user_id)
+ REFERENCES users (name)
+);
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cfafba22c5..c9dce726cb 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -540,7 +540,7 @@ class StateGroupStorage:
state_filter: The state filter used to fetch state from the database.
Returns:
- A dict from (type, state_key) -> state_event
+ A dict from (type, state_key) -> state_event_id
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
        )