diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab49d227de..2b196ded1b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta):
"""
try:
- if key is None:
- getattr(self, cache_name).invalidate_all()
- else:
- getattr(self, cache_name).invalidate(tuple(key))
+ cache = getattr(self, cache_name)
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
- pass
+ return
+
+ if key is None:
+ cache.invalidate_all()
+ else:
+ cache.invalidate(tuple(key))
def db_to_json(db_content):
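The hunk above narrows the `try` block so that only the attribute lookup is guarded. A minimal standalone sketch (hypothetical `Cache`/`Store` classes, not Synapse code) of why that matters: with the broad version, an `AttributeError` raised inside `invalidate()` itself is silently swallowed.

```python
class Cache:
    def invalidate(self, key):
        raise AttributeError("bug inside invalidate()")  # simulated cache bug


class Store:
    names = Cache()

    def invalidate_broad(self, cache_name, key):
        try:
            # Broad guard: also swallows errors raised by invalidate() itself.
            getattr(self, cache_name).invalidate(key)
        except AttributeError:
            pass

    def invalidate_narrow(self, cache_name, key):
        try:
            cache = getattr(self, cache_name)
        except AttributeError:
            # Cache not present on this worker, which is fine.
            return
        cache.invalidate(key)  # bugs in invalidate() now surface


store = Store()
store.invalidate_broad("names", ("key",))       # bug silently hidden
try:
    store.invalidate_narrow("names", ("key",))  # bug raises, as it should
except AttributeError as e:
    print("surfaced:", e)
```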
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6116191b16..d1b5760c2c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
overload,
)
+import attr
from prometheus_client import Histogram
from typing_extensions import Literal
@@ -87,16 +88,25 @@ def make_pool(
"""Get the connection pool for the database.
"""
+ # By default enable `cp_reconnect`. We need to fiddle with db_args in case
+ # someone has explicitly set `cp_reconnect`.
+ db_args = dict(db_config.config.get("args", {}))
+ db_args.setdefault("cp_reconnect", True)
+
return adbapi.ConnectionPool(
db_config.config["name"],
cp_reactor=reactor,
- cp_openfun=engine.on_new_connection,
- **db_config.config.get("args", {})
+ cp_openfun=lambda conn: engine.on_new_connection(
+ LoggingDatabaseConnection(conn, engine, "on_new_connection")
+ ),
+ **db_args,
)
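As a note on the `cp_reconnect` handling above: `dict.setdefault` only fills the key in when the deployer has not set it explicitly, and copying the config's `args` dict first avoids mutating the parsed configuration. A small sketch with a hypothetical `args` mapping:

```python
configured = {"user": "synapse", "cp_min": 5}  # hypothetical database `args`

db_args = dict(configured)                # copy, so the config stays untouched
db_args.setdefault("cp_reconnect", True)  # default applied
assert db_args["cp_reconnect"] is True
assert "cp_reconnect" not in configured

db_args = dict({"user": "synapse", "cp_reconnect": False})
db_args.setdefault("cp_reconnect", True)  # explicit setting wins
assert db_args["cp_reconnect"] is False
```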
def make_conn(
- db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
+ default_txn_name: str,
) -> Connection:
"""Make a new connection to the database and return it.
@@ -109,11 +119,60 @@ def make_conn(
for k, v in db_config.config.get("args", {}).items()
if not k.startswith("cp_")
}
- db_conn = engine.module.connect(**db_params)
+ native_db_conn = engine.module.connect(**db_params)
+ db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
engine.on_new_connection(db_conn)
return db_conn
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+ """A wrapper around a database connection that returns `LoggingTransaction`
+ as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly.
+ """
+
+ conn = attr.ib(type=Connection)
+ engine = attr.ib(type=BaseDatabaseEngine)
+ default_txn_name = attr.ib(type=str)
+
+ def cursor(
+ self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+ ) -> "LoggingTransaction":
+ if not txn_name:
+ txn_name = self.default_txn_name
+
+ return LoggingTransaction(
+ self.conn.cursor(),
+ name=txn_name,
+ database_engine=self.engine,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
+ )
+
+ def close(self) -> None:
+ self.conn.close()
+
+ def commit(self) -> None:
+ self.conn.commit()
+
+ def rollback(self, *args, **kwargs) -> None:
+ self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "LoggingDatabaseConnection":
+ self.conn.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+ return self.conn.__exit__(exc_type, exc_value, traceback)
+
+ # Proxy through any unknown lookups to the DB conn class.
+ def __getattr__(self, name):
+ return getattr(self.conn, name)
+
+
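A runnable sketch of the wrapper pattern `LoggingDatabaseConnection` uses, simplified and with `sqlite3` standing in for the real driver: methods that need decorating are defined explicitly, while `__getattr__`, which Python only consults when normal attribute lookup fails, transparently proxies everything else to the wrapped connection.

```python
import sqlite3


class LoggingConn:
    """Simplified stand-in for LoggingDatabaseConnection (not Synapse code)."""

    def __init__(self, conn, default_txn_name):
        self._conn = conn
        self._default_txn_name = default_txn_name

    def cursor(self, txn_name=None):
        name = txn_name or self._default_txn_name
        print("opening cursor for %r" % name)  # stand-in for real logging
        return self._conn.cursor()

    def __getattr__(self, name):
        # Only hit for attributes not defined above, e.g. commit()/close().
        return getattr(self._conn, name)


conn = LoggingConn(sqlite3.connect(":memory:"), "startup")
cur = conn.cursor(txn_name="create_table")
cur.execute("CREATE TABLE t (x INTEGER)")
conn.commit()  # proxied through __getattr__ to the sqlite3 connection
conn.close()
```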
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +306,12 @@ class LoggingTransaction:
def close(self) -> None:
self.txn.close()
+ def __enter__(self) -> "LoggingTransaction":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
class PerformanceCounters:
def __init__(self):
@@ -395,7 +460,7 @@ class DatabasePool:
def new_transaction(
self,
- conn: Connection,
+ conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
@@ -436,12 +501,10 @@ class DatabasePool:
i = 0
N = 5
while True:
- cursor = LoggingTransaction(
- conn.cursor(),
- name,
- self.engine,
- after_callbacks,
- exception_callbacks,
+ cursor = conn.cursor(
+ txn_name=name,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
)
try:
r = func(cursor, *args, **kwargs)
@@ -574,7 +637,7 @@ class DatabasePool:
func,
*args,
db_autocommit=db_autocommit,
- **kwargs
+ **kwargs,
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -638,7 +701,10 @@ class DatabasePool:
if db_autocommit:
self.engine.attempt_to_set_autocommit(conn, True)
- return func(conn, *args, **kwargs)
+ db_conn = LoggingDatabaseConnection(
+ conn, self.engine, "runWithConnection"
+ )
+ return func(db_conn, *args, **kwargs)
finally:
if db_autocommit:
self.engine.attempt_to_set_autocommit(conn, False)
@@ -832,6 +898,12 @@ class DatabasePool:
attempts = 0
while True:
try:
+ # We can autocommit if we are going to use native upserts
+ autocommit = (
+ self.engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ )
+
return await self.runInteraction(
desc,
self.simple_upsert_txn,
@@ -840,6 +912,7 @@ class DatabasePool:
values,
insertion_values,
lock=lock,
+ db_autocommit=autocommit,
)
except self.engine.module.IntegrityError as e:
attempts += 1
@@ -1002,6 +1075,43 @@ class DatabasePool:
)
txn.execute(sql, list(allvalues.values()))
+ async def simple_upsert_many(
+ self,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ desc: str,
+ ) -> None:
+ """
+ Upsert, many times.
+
+ Args:
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
+ """
+
+ # We can autocommit if we are going to use native upserts
+ autocommit = (
+ self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
+ )
+
+ return await self.runInteraction(
+ desc,
+ self.simple_upsert_many_txn,
+ table,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ db_autocommit=autocommit,
+ )
+
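The `db_autocommit` flag set above relies on a "native" upsert being a single atomic statement, so it needs no surrounding transaction and no retry-on-`IntegrityError` loop. A hedged sketch of such an upsert using `sqlite3` (SQLite grew `ON CONFLICT ... DO UPDATE` in 3.24; the PostgreSQL syntax is analogous):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE state (as_id TEXT PRIMARY KEY, state TEXT)")

upsert = """
    INSERT INTO state (as_id, state) VALUES (?, ?)
    ON CONFLICT (as_id) DO UPDATE SET state = excluded.state
"""
conn.execute(upsert, ("as1", "up"))
conn.execute(upsert, ("as1", "down"))  # second run updates, never raises
row = conn.execute("SELECT state FROM state WHERE as_id = 'as1'").fetchone()
assert row == ("down",)
```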
def simple_upsert_many_txn(
self,
txn: LoggingTransaction,
@@ -1153,7 +1263,13 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
return await self.runInteraction(
- desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+ desc,
+ self.simple_select_one_txn,
+ table,
+ keyvalues,
+ retcols,
+ allow_none,
+ db_autocommit=True,
)
@overload
@@ -1204,6 +1320,7 @@ class DatabasePool:
keyvalues,
retcol,
allow_none=allow_none,
+ db_autocommit=True,
)
@overload
@@ -1285,7 +1402,12 @@ class DatabasePool:
Results in a list
"""
return await self.runInteraction(
- desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+ desc,
+ self.simple_select_onecol_txn,
+ table,
+ keyvalues,
+ retcol,
+ db_autocommit=True,
)
async def simple_select_list(
@@ -1310,7 +1432,12 @@ class DatabasePool:
A list of dictionaries.
"""
return await self.runInteraction(
- desc, self.simple_select_list_txn, table, keyvalues, retcols
+ desc,
+ self.simple_select_list_txn,
+ table,
+ keyvalues,
+ retcols,
+ db_autocommit=True,
)
@classmethod
@@ -1389,6 +1516,7 @@ class DatabasePool:
chunk,
keyvalues,
retcols,
+ db_autocommit=True,
)
results.extend(rows)
@@ -1487,7 +1615,12 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
- desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+ desc,
+ self.simple_update_one_txn,
+ table,
+ keyvalues,
+ updatevalues,
+ db_autocommit=True,
)
@classmethod
@@ -1546,7 +1679,9 @@ class DatabasePool:
keyvalues: dict of column names and values to select the row with
desc: description of the transaction, for logging and metrics
"""
- await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+ await self.runInteraction(
+ desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+ )
@staticmethod
def simple_delete_one_txn(
@@ -1585,7 +1720,9 @@ class DatabasePool:
Returns:
The number of deleted rows.
"""
- return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+ return await self.runInteraction(
+ desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
+ )
@staticmethod
def simple_delete_txn(
@@ -1633,7 +1770,13 @@ class DatabasePool:
Number rows deleted
"""
return await self.runInteraction(
- desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+ desc,
+ self.simple_delete_many_txn,
+ table,
+ column,
+ iterable,
+ keyvalues,
+ db_autocommit=True,
)
@staticmethod
@@ -1678,7 +1821,7 @@ class DatabasePool:
def get_cache_dict(
self,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
table: str,
entity_column: str,
stream_column: str,
@@ -1699,9 +1842,7 @@ class DatabasePool:
"limit": limit,
}
- sql = self.engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="get_cache_dict")
txn.execute(sql, (int(max_value),))
cache = {row[0]: int(row[1]) for row in txn}
@@ -1801,7 +1942,13 @@ class DatabasePool:
"""
return await self.runInteraction(
- desc, self.simple_search_list_txn, table, term, col, retcols
+ desc,
+ self.simple_search_list_txn,
+ table,
+ term,
+ col,
+ retcols,
+ db_autocommit=True,
)
@classmethod
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
db_name = database_config.name
engine = create_engine(database_config.config)
- with make_conn(database_config, engine) as db_conn:
+ with make_conn(database_config, engine, "startup") as db_conn:
logger.info("[database config %r]: Checking database server", db_name)
engine.check_database(db_conn)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0cb12f4c61..43660ec4fb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import calendar
import logging
-import time
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
@@ -148,7 +146,6 @@ class DataStore(
db_conn, "e2e_cross_signing_keys", "stream_id"
)
- self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
@@ -268,9 +265,6 @@ class DataStore(
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
- # Used in _generate_user_daily_visits to keep track of progress
- self._last_user_visit_update = self._get_start_of_day()
-
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -289,7 +283,6 @@ class DataStore(
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
- sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
@@ -301,192 +294,6 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- async def count_daily_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 24 hours.
- """
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return await self.db_pool.runInteraction(
- "count_daily_users", self._count_users, yesterday
- )
-
- async def count_monthly_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 30 days.
- Note this method is intended for phonehome metrics only and is different
- from the mau figure in synapse.storage.monthly_active_users which,
- amongst other things, includes a 3 day grace period before a user counts.
- """
- thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return await self.db_pool.runInteraction(
- "count_monthly_users", self._count_users, thirty_days_ago
- )
-
- def _count_users(self, txn, time_from):
- """
- Returns number of users seen in the past time_from period
- """
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
- txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
- return count
-
- async def count_r30_users(self) -> Dict[str, int]:
- """
- Counts the number of 30 day retained users, defined as:-
- * Users who have created their accounts more than 30 days ago
- * Where last seen at most 30 days ago
- * Where account creation and last_seen are > 30 days apart
-
- Returns:
- A mapping of counts globally as well as broken out by platform.
- """
-
- def _count_r30_users(txn):
- thirty_days_in_secs = 86400 * 30
- now = int(self._clock.time())
- thirty_days_ago_in_secs = now - thirty_days_in_secs
-
- sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
- SELECT
- users.name, platform, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen,
- CASE
- WHEN user_agent LIKE '%%Android%%' THEN 'android'
- WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
- WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
- WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
- WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
- ELSE 'unknown'
- END
- AS platform
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND users.appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, platform, users.creation_ts
- ) u GROUP BY platform
- """
-
- results = {}
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- for row in txn:
- if row[0] == "unknown":
- pass
- results[row[0]] = row[1]
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT users.name, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, users.creation_ts
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- (count,) = txn.fetchone()
- results["all"] = count
-
- return results
-
- return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
- def _get_start_of_day(self):
- """
- Returns millisecond unixtime for start of UTC day.
- """
- now = time.gmtime()
- today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
- return today_start * 1000
-
- async def generate_user_daily_visits(self) -> None:
- """
- Generates daily visit data for use in cohort/ retention analysis
- """
-
- def _generate_user_daily_visits(txn):
- logger.info("Calling _generate_user_daily_visits")
- today_start = self._get_start_of_day()
- a_day_in_milliseconds = 24 * 60 * 60 * 1000
- now = self.clock.time_msec()
-
- sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
- FROM user_ips AS u
- LEFT JOIN (
- SELECT user_id, device_id, timestamp FROM user_daily_visits
- WHERE timestamp = ?
- ) udv
- ON u.user_id = udv.user_id AND u.device_id=udv.device_id
- INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
- AND udv.timestamp IS NULL AND users.is_guest=0
- AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
- """
-
- # This means that the day has rolled over but there could still
- # be entries from the previous day. There is an edge case
- # where if the user logs in at 23:59 and overwrites their
- # last_seen at 00:01 then they will not be counted in the
- # previous day's stats - it is important that the query is run
- # often to minimise this case.
- if today_start > self._last_user_visit_update:
- yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(
- sql,
- (
- yesterday_start,
- yesterday_start,
- self._last_user_visit_update,
- today_start,
- ),
- )
- self._last_user_visit_update = today_start
-
- txn.execute(
- sql, (today_start, today_start, self._last_user_visit_update, now)
- )
- # Update _last_user_visit_update to now. The reason to do this
- # rather just clamping to the beginning of the day is to limit
- # the size of the join - meaning that the query can be run more
- # frequently
- self._last_user_visit_update = now
-
- await self.db_pool.runInteraction(
- "generate_user_daily_visits", _generate_user_daily_visits
- )
-
async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef81d73573..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
import logging
from typing import Dict, List, Optional, Tuple
+from synapse.api.constants import AccountDataTypes
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
) -> bool:
ignored_account_data = await self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list",
+ AccountDataTypes.IGNORED_USER_LIST,
ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
return False
- return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ try:
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ except TypeError:
+ # The type of the ignored_users field is invalid.
+ return False
class AccountDataStore(AccountDataWorkerStore):
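A tiny illustration of the `TypeError` guard above: if a client has stored `ignored_users` as something that is not a container (say, an integer), the `in` test raises rather than returning `False`.

```python
account_data = {"ignored_users": 42}  # malformed client-supplied data

try:
    ignored = "@spammer:example.com" in account_data.get("ignored_users", {})
except TypeError:
    # The type of the ignored_users field is invalid.
    ignored = False

assert ignored is False
```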
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 85f6b1e3fd..e550cbc866 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,18 +15,31 @@
# limitations under the License.
import logging
import re
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
-from synapse.appservice import AppServiceTransaction
+from synapse.appservice import (
+ ApplicationService,
+ ApplicationServiceState,
+ AppServiceTransaction,
+)
from synapse.config.appservice import load_appservices
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+ services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
@@ -36,17 +49,19 @@ def _make_exclusive_regex(services_cache):
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
- exclusive_user_regex = re.compile(exclusive_user_regex)
+ exclusive_user_pattern = re.compile(
+ exclusive_user_regex
+ ) # type: Optional[Pattern]
else:
# We handle this case specially otherwise the constructed regex
# will always match
- exclusive_user_regex = None
+ exclusive_user_pattern = None
- return exclusive_user_regex
+ return exclusive_user_pattern
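A self-contained sketch of the construction above: each appservice pattern is parenthesised and the results OR-ed together, so a single compiled regex answers "is this user exclusively claimed by any appservice?", and the empty case returns `None` rather than a regex that would match everything.

```python
import re
from typing import List, Optional, Pattern


def make_exclusive_regex(regexes: List[str]) -> Optional[Pattern]:
    if not regexes:
        return None  # avoid compiling "" which would match every user
    return re.compile("|".join("(" + r + ")" for r in regexes))


pattern = make_exclusive_regex([r"@irc_.*:example\.com", r"@slack_.*:example\.com"])
assert pattern is not None
assert pattern.match("@irc_alice:example.com")
assert not pattern.match("@bob:example.com")
assert make_exclusive_regex([]) is None
```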
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -57,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
- def get_if_app_services_interested_in_user(self, user_id):
+ def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
if self.exclusive_user_regex:
@@ -65,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else:
return False
- def get_app_service_by_user_id(self, user_id):
+ def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
@@ -74,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service.
Args:
- user_id(str): The user ID to see if it is an application service.
+ user_id: The user ID to see if it is an application service.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.sender == user_id:
return service
return None
- def get_app_service_by_token(self, token):
+ def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token.
Args:
- token (str): The application service token.
+ token: The application service token.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.token == token:
return service
return None
- def get_app_service_by_id(self, as_id):
+ def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID.
Args:
- as_id (str): The application service ID.
+ as_id: The application service ID.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.id == as_id:
@@ -121,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
- async def get_appservices_by_state(self, state):
+ async def get_appservices_by_state(
+ self, state: ApplicationServiceState
+ ) -> List[ApplicationService]:
"""Get a list of application services based on their state.
Args:
- state(ApplicationServiceState): The state to filter on.
+ state: The state to filter on.
Returns:
A list of ApplicationServices, which may be empty.
"""
@@ -142,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
- async def get_appservice_state(self, service):
+ async def get_appservice_state(
+ self, service: ApplicationService
+ ) -> Optional[ApplicationServiceState]:
"""Get the application service state.
Args:
- service(ApplicationService): The service whose state to set.
+ service: The service whose state to set.
Returns:
- An ApplicationServiceState.
+            An ApplicationServiceState, or None.
"""
result = await self.db_pool.simple_select_one(
"application_services_state",
@@ -161,26 +180,36 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- async def set_appservice_state(self, service, state) -> None:
+ async def set_appservice_state(
+ self, service: ApplicationService, state: ApplicationServiceState
+ ) -> None:
"""Set the application service state.
Args:
- service(ApplicationService): The service whose state to set.
- state(ApplicationServiceState): The connectivity state to apply.
+ service: The service whose state to set.
+ state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
- async def create_appservice_txn(self, service, events):
+ async def create_appservice_txn(
+ self,
+ service: ApplicationService,
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ ) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
- with the given list of events.
+ with the given list of events. Ephemeral events are NOT persisted to the
+ database and are not resent if a transaction is retried.
Args:
- service(ApplicationService): The service who the transaction is for.
- events(list<Event>): A list of events to put in the transaction.
+ service: The service who the transaction is for.
+ events: A list of persistent events to put in the transaction.
+ ephemeral: A list of ephemeral events to put in the transaction.
+
Returns:
- AppServiceTransaction: A new transaction.
+ A new transaction.
"""
def _create_appservice_txn(txn):
@@ -207,19 +236,22 @@ class ApplicationServiceTransactionWorkerStore(
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids),
)
- return AppServiceTransaction(service=service, id=new_txn_id, events=events)
+ return AppServiceTransaction(
+ service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+ )
return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
- async def complete_appservice_txn(self, txn_id, service) -> None:
+ async def complete_appservice_txn(
+ self, txn_id: int, service: ApplicationService
+ ) -> None:
"""Completes an application service transaction.
Args:
- txn_id(str): The transaction ID being completed.
- service(ApplicationService): The application service which was sent
- this transaction.
+ txn_id: The transaction ID being completed.
+ service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
@@ -259,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn
)
- async def get_oldest_unsent_txn(self, service):
- """Get the oldest transaction which has not been sent for this
- service.
+ async def get_oldest_unsent_txn(
+ self, service: ApplicationService
+ ) -> Optional[AppServiceTransaction]:
+ """Get the oldest transaction which has not been sent for this service.
Args:
- service(ApplicationService): The app service to get the oldest txn.
+ service: The app service to get the oldest txn.
Returns:
An AppServiceTransaction or None.
"""
@@ -296,9 +329,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
+ return AppServiceTransaction(
+ service=service, id=entry["txn_id"], events=events, ephemeral=[]
+ )
- def _get_last_txn(self, txn, service_id):
+ def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
@@ -309,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
- async def set_appservice_last_pos(self, pos) -> None:
+ async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -319,8 +354,10 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
- async def get_new_events_for_appservice(self, current_id, limit):
- """Get all new evnets"""
+ async def get_new_events_for_appservice(
+ self, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
+ """Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
sql = (
@@ -351,6 +388,54 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events
+ async def get_type_stream_id_for_appservice(
+ self, service: ApplicationService, type: str
+ ) -> int:
+ if type not in ("read_receipt", "presence"):
+ raise ValueError(
+ "Expected type to be a valid application stream id type, got %s"
+ % (type,)
+ )
+
+ def get_type_stream_id_for_appservice_txn(txn):
+ stream_id_type = "%s_stream_id" % type
+ txn.execute(
+ # We do NOT want to escape `stream_id_type`.
+ "SELECT %s FROM application_services_state WHERE as_id=?"
+ % stream_id_type,
+ (service.id,),
+ )
+ last_stream_id = txn.fetchone()
+ if last_stream_id is None or last_stream_id[0] is None: # no row exists
+ return 0
+ else:
+ return int(last_stream_id[0])
+
+ return await self.db_pool.runInteraction(
+ "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
+ )
+
+ async def set_type_stream_id_for_appservice(
+ self, service: ApplicationService, type: str, pos: Optional[int]
+ ) -> None:
+ if type not in ("read_receipt", "presence"):
+ raise ValueError(
+ "Expected type to be a valid application stream id type, got %s"
+ % (type,)
+ )
+
+ def set_type_stream_id_for_appservice_txn(txn):
+ stream_id_type = "%s_stream_id" % type
+ txn.execute(
+ "UPDATE application_services_state SET %s = ? WHERE as_id=?"
+ % stream_id_type,
+ (pos, service.id),
+ )
+
+ await self.db_pool.runInteraction(
+ "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+ )
+
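Column names cannot be bound as SQL parameters, which is why the two methods above interpolate `stream_id_type` into the query string; the up-front `ValueError` check is what makes that safe. A minimal sketch of the same whitelist-then-interpolate pattern (mirroring the source's use of `type` as the parameter name):

```python
def build_query(type: str) -> str:
    if type not in ("read_receipt", "presence"):
        raise ValueError("unexpected stream id type %s" % (type,))
    # Safe: `type` is now one of two known identifiers; values still use `?`.
    return (
        "SELECT %s_stream_id FROM application_services_state WHERE as_id=?" % type
    )


assert "read_receipt_stream_id" in build_query("read_receipt")
try:
    build_query("as_id; DROP TABLE application_services_state")
except ValueError:
    pass  # injection attempt rejected before reaching the database
```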
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..3e26d5ba87 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,12 +17,12 @@ import logging
from typing import TYPE_CHECKING
from synapse.events.utils import prune_event_dict
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- def _censor_redactions():
- return run_as_background_process(
- "_censor_redactions", self._censor_redactions
- )
-
- if self.hs.config.redaction_retention_period is not None:
- hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+ if (
+ hs.config.run_background_tasks
+ and self.hs.config.redaction_retention_period is not None
+ ):
+ hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
+ @wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self):
"""Censors all redactions older than the configured period that haven't
been censored yet.
@@ -105,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = encode_json(
+ pruned_json = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -171,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(
+ pruned_json = json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 239c7a949c..339bd691a4 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -351,16 +351,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
+
+ if hs.config.run_background_tasks and self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.db_pool.updates.has_completed_background_update(
+ "devices_last_seen"
+ ):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try and delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
+ )
+ """
- self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.db_pool.runInteraction(
+ "_prune_old_user_ips", _prune_old_user_ips_txn
)
- super().__init__(database, db_conn, hs)
- self.user_ips_max_age = hs.config.user_ips_max_age
+class ClientIpStore(ClientIpWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+
+ self.client_ip_last_seen = LruCache(
+ cache_name="client_ip_last_seen", keylen=4, max_size=50000
+ )
+
+ super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
- if self.user_ips_max_age:
- self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
@@ -391,7 +442,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
- self.client_ip_last_seen.prefill(key, now)
+ self.client_ip_last_seen.set(key, now)
self._batch_row_update[key] = (user_agent, device_id, now)
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]
-
- @wrap_as_background_process("prune_old_user_ips")
- async def _prune_old_user_ips(self):
- """Removes entries in user IPs older than the configured period.
- """
-
- if self.user_ips_max_age is None:
- # Nothing to do
- return
-
- if not await self.db_pool.updates.has_completed_background_update(
- "devices_last_seen"
- ):
- # Only start pruning if we have finished populating the devices
- # last seen info.
- return
-
- # We do a slightly funky SQL delete to ensure we don't try and delete
- # too much at once (as the table may be very large from before we
- # started pruning).
- #
- # This works by finding the max last_seen that is less than the given
- # time, but has no more than N rows before it, deleting all rows with
- # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
- # returns exactly one row).
- sql = """
- DELETE FROM user_ips
- WHERE last_seen <= (
- SELECT COALESCE(MAX(last_seen), -1)
- FROM (
- SELECT last_seen FROM user_ips
- WHERE last_seen <= ?
- ORDER BY last_seen ASC
- LIMIT 5000
- ) AS u
- )
- """
-
- timestamp = self.clock.time_msec() - self.user_ips_max_age
-
- def _prune_old_user_ips_txn(txn):
- txn.execute(sql, (timestamp,))
-
- await self.db_pool.runInteraction(
- "_prune_old_user_ips", _prune_old_user_ips_txn
- )
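The bounded delete moved into `ClientIpWorkerStore` above is worth a worked example. The sub-select finds the newest `last_seen` among the N oldest eligible rows, and the outer delete removes everything at or below it, so each pass deletes at most N rows. A runnable sketch against `sqlite3` with a batch size of 5:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen INTEGER)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?)", [("u%d" % i, i) for i in range(20)]
)

sql = """
    DELETE FROM user_ips
    WHERE last_seen <= (
        SELECT COALESCE(MAX(last_seen), -1)
        FROM (
            SELECT last_seen FROM user_ips
            WHERE last_seen <= ?
            ORDER BY last_seen ASC
            LIMIT 5
        ) AS u
    )
"""
conn.execute(sql, (15,))  # prune rows older than 15, at most 5 per pass
assert conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0] == 15
```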
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..dfb4f87b8f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
trace,
whitelisted_homeserver,
)
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -33,8 +33,9 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
-from synapse.util.caches.descriptors import Cache, cached, cachedList
+from synapse.util import json_decoder, json_encoder
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -48,6 +49,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(
+ self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+ )
+
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -698,6 +707,172 @@ class DeviceWorkerStore(SQLBaseStore):
_mark_remote_user_device_list_as_unsubscribed_txn,
)
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ # FIXME: make sure device ID still exists in devices table
+ row = await self.db_pool.simple_select_one(
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcols=["device_id", "device_data"],
+ allow_none=True,
+ )
+ return (
+ (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+ )
+
+ def _store_dehydrated_device_txn(
+ self, txn, user_id: str, device_id: str, device_data: str
+ ) -> Optional[str]:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcol="device_id",
+ allow_none=True,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ values={"device_id": device_id, "device_data": device_data},
+ )
+ return old_device_id
+
+ async def store_dehydrated_device(
+ self, user_id: str, device_id: str, device_data: JsonDict
+ ) -> Optional[str]:
+ """Store a dehydrated device for a user.
+
+ Args:
+ user_id: the user that we are storing the device for
+ device_id: the ID of the dehydrated device
+ device_data: the dehydrated device information
+ Returns:
+ device id of the user's previous dehydrated device, if any
+ """
+ return await self.db_pool.runInteraction(
+ "store_dehydrated_device_txn",
+ self._store_dehydrated_device_txn,
+ user_id,
+ device_id,
+ json_encoder.encode(device_data),
+ )
+
+ async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+ """Remove a dehydrated device.
+
+ Args:
+ user_id: the user that the dehydrated device belongs to
+ device_id: the ID of the dehydrated device
+ """
+ count = await self.db_pool.simple_delete(
+ "dehydrated_devices",
+ {"user_id": user_id, "device_id": device_id},
+ desc="remove_dehydrated_device",
+ )
+ return count >= 1
+
+ @wrap_as_background_process("prune_old_outbound_device_pokes")
+ async def _prune_old_outbound_device_pokes(
+ self, prune_age: int = 24 * 60 * 60 * 1000
+ ) -> None:
+ """Delete old entries out of the device_lists_outbound_pokes to ensure
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send a m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
+ """
+ yesterday = self._clock.time_msec() - prune_age
+
+ def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
+ select_sql = """
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
+ GROUP BY destination, user_id
+ HAVING min(ts) < ? AND count(*) > 1
+ """
+
+ txn.execute(select_sql, (yesterday,))
+ rows = txn.fetchall()
+
+ if not rows:
+ return
+
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+ # same stream_id, so we also delete all but one rows with the max stream id.
+ delete_sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
+ """
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
+
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # of last successful sent so that the prev_ids are correctly set.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+ logger.info("Pruned %d device list outbound pokes", count)
+
+ await self.db_pool.runInteraction(
+ "_prune_old_outbound_device_pokes", _prune_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -830,14 +1005,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
- self.device_id_exists_cache = Cache(
- name="device_id_exists", keylen=2, max_entries=10000
+ self.device_id_exists_cache = LruCache(
+ cache_name="device_id_exists", keylen=2, max_size=10000
)
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: str
+ self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -879,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
- self.device_id_exists_cache.prefill(key, True)
+ self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
raise
@@ -955,7 +1128,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def update_remote_device_list_cache_entry(
- self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
@@ -983,7 +1156,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
content: JsonDict,
- stream_id: int,
+ stream_id: str,
) -> None:
if content.get("deleted"):
self.db_pool.simple_delete_txn(
@@ -1193,95 +1366,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids
],
)
-
- def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
- """Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers.
-
- Normally, we try to send device updates as a delta since a previous known point:
- this is done by setting the prev_id in the m.device_list_update EDU. However,
- for that to work, we have to have a complete record of each change to
- each device, which can add up to quite a lot of data.
-
- An alternative mechanism is that, if the remote server sees that it has missed
- an entry in the stream_id sequence for a given user, it will request a full
- list of that user's devices. Hence, we can reduce the amount of data we have to
- store (and transmit in some future transaction), by clearing almost everything
- for a given destination out of the database, and having the remote server
- resync.
-
- All we need to do is make sure we keep at least one row for each
- (user, destination) pair, to remind us to send a m.device_list_update EDU for
- that user when the destination comes back. It doesn't matter which device
- we keep.
- """
- yesterday = self._clock.time_msec() - prune_age
-
- def _prune_txn(txn):
- # look for (user, destination) pairs which have an update older than
- # the cutoff.
- #
- # For each pair, we also need to know the most recent stream_id, and
- # an arbitrary device_id at that stream_id.
- select_sql = """
- SELECT
- dlop1.destination,
- dlop1.user_id,
- MAX(dlop1.stream_id) AS stream_id,
- (SELECT MIN(dlop2.device_id) AS device_id FROM
- device_lists_outbound_pokes dlop2
- WHERE dlop2.destination = dlop1.destination AND
- dlop2.user_id=dlop1.user_id AND
- dlop2.stream_id=MAX(dlop1.stream_id)
- )
- FROM device_lists_outbound_pokes dlop1
- GROUP BY destination, user_id
- HAVING min(ts) < ? AND count(*) > 1
- """
-
- txn.execute(select_sql, (yesterday,))
- rows = txn.fetchall()
-
- if not rows:
- return
-
- logger.info(
- "Pruning old outbound device list updates for %i users/destinations: %s",
- len(rows),
- shortstr((row[0], row[1]) for row in rows),
- )
-
- # we want to keep the update with the highest stream_id for each user.
- #
- # there might be more than one update (with different device_ids) with the
- # same stream_id, so we also delete all but one rows with the max stream id.
- delete_sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND (
- stream_id < ? OR
- (stream_id = ? AND device_id != ?)
- )
- """
- count = 0
- for (destination, user_id, stream_id, device_id) in rows:
- txn.execute(
- delete_sql, (destination, user_id, stream_id, stream_id, device_id)
- )
- count += txn.rowcount
-
- # Since we've deleted unsent deltas, we need to remove the entry
- # of last successful sent so that the prev_ids are correctly set.
- sql = """
- DELETE FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
- logger.info("Pruned %d device list outbound pokes", count)
-
- return run_as_background_process(
- "prune_old_outbound_device_pokes",
- self.db_pool.runInteraction,
- "_prune_old_outbound_device_pokes",
- _prune_txn,
- )
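The keep-one-row rule in `_prune_old_outbound_device_pokes` can be seen in isolation: for each `(destination, user_id)` pair we keep a single row at the maximum `stream_id` and delete both the older rows and any duplicates at that maximum. A runnable sketch of the `delete_sql` step using `sqlite3` and made-up rows:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE pokes (destination TEXT, user_id TEXT,"
    " stream_id INTEGER, device_id TEXT)"
)
conn.executemany(
    "INSERT INTO pokes VALUES (?, ?, ?, ?)",
    [
        ("remote.example", "@u:hs", 1, "DEV1"),
        ("remote.example", "@u:hs", 2, "DEV2"),
        ("remote.example", "@u:hs", 2, "DEV3"),  # duplicate max stream_id
    ],
)

# Keep the (stream_id=2, DEV2) row; delete older rows and other duplicates.
conn.execute(
    """
    DELETE FROM pokes
    WHERE destination = ? AND user_id = ? AND (
        stream_id < ? OR (stream_id = ? AND device_id != ?)
    )
    """,
    ("remote.example", "@u:hs", 2, 2, "DEV2"),
)
assert conn.execute("SELECT COUNT(*) FROM pokes").fetchone()[0] == 1
```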
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..4d1b92d1aa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -33,6 +33,7 @@ from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
+ from synapse.server import HomeServer
@attr.s(slots=True)
@@ -47,7 +48,20 @@ class DeviceKeyLookupResult:
keys = attr.ib(type=Optional[JsonDict])
-class EndToEndKeyWorkerStore(SQLBaseStore):
+class EndToEndKeyBackgroundStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "e2e_cross_signing_keys_idx",
+ index_name="e2e_cross_signing_keys_stream_idx",
+ table="e2e_cross_signing_keys",
+ columns=["stream_id"],
+ unique=True,
+ )
+
+
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
@@ -367,6 +381,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+ async def set_e2e_fallback_keys(
+ self, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
+ # fallback_keys will usually only have one item in it, so using a for
+ # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+ # FIXME: make sure that only one key per algorithm is uploaded
+ for key_id, fallback_key in fallback_keys.items():
+ algorithm, key_id = key_id.split(":", 1)
+ await self.db_pool.simple_upsert(
+ "e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ desc="set_e2e_fallback_key",
+ )
+
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ @cached(max_entries=10000)
+ async def get_e2e_unused_fallback_key_types(
+ self, user_id: str, device_id: str
+ ) -> List[str]:
+ """Returns the fallback key types that have an unused key.
+
+ Args:
+ user_id: the user whose keys are being queried
+ device_id: the device whose keys are being queried
+
+ Returns:
+ a list of key types
+ """
+ return await self.db_pool.simple_select_onecol(
+ "e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+ retcol="algorithm",
+ desc="get_e2e_unused_fallback_key_types",
+ )
+
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
@@ -701,15 +770,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
+ fallback_sql = (
+ "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
result = {}
delete = []
+ used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
+ otk_row = txn.fetchone()
+ if otk_row is not None:
+ key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
+ else:
+ # no one-time key available, so see if there's a fallback
+ # key
+ txn.execute(fallback_sql, (user_id, device_id, algorithm))
+ fallback_row = txn.fetchone()
+ if fallback_row is not None:
+ key_id, key_json, used = fallback_row
+ device_result[algorithm + ":" + key_id] = key_json
+ if not used:
+ used_fallbacks.append(
+ (user_id, device_id, algorithm, key_id)
+ )
+
+ # drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +817,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ # mark fallback keys as used
+ for user_id, device_id, algorithm, key_id in used_fallbacks:
+ self.db_pool.simple_update_txn(
+ txn,
+ "e2e_fallback_keys_json",
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ {"used": True},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
return result
return await self.db_pool.runInteraction(
@@ -754,6 +862,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..2e07c37340 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,19 +19,33 @@ from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
from synapse.events import EventBase
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ hs.get_clock().looping_call(
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ )
+
+ # Cache of event ID to list of auth event IDs and their depths.
+ self._event_auth_cache = LruCache(
+ 500000, "_event_auth_cache", size_callback=len
+ ) # type: LruCache[str, List[Tuple[str, int]]]
+
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -76,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
results = set()
- base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
+ # We pull out the depth simply so that we can populate the
+ # `_event_auth_cache` cache.
+ base_sql = """
+ SELECT a.event_id, auth_id, depth
+ FROM event_auth AS a
+ INNER JOIN events AS e ON (e.event_id = a.auth_id)
+ WHERE
+ """
front = set(event_ids)
while front:
new_front = set()
for chunk in batch_iter(front, 100):
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", chunk
- )
- txn.execute(base_sql + clause, args)
- new_front.update(r[0] for r in txn)
+ # Pull the auth events either from the cache or DB.
+ to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ for event_id in chunk:
+ res = self._event_auth_cache.get(event_id)
+ if res is None:
+ to_fetch.append(event_id)
+ else:
+ new_front.update(auth_id for auth_id, depth in res)
+
+ if to_fetch:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", to_fetch
+ )
+ txn.execute(base_sql + clause, args)
+
+ # Note we need to batch up the results by event ID before
+ # adding to the cache.
+ to_cache = {}
+ for event_id, auth_event_id, auth_event_depth in txn:
+ to_cache.setdefault(event_id, []).append(
+ (auth_event_id, auth_event_depth)
+ )
+ new_front.add(auth_event_id)
+
+ for event_id, auth_events in to_cache.items():
+ self._event_auth_cache.set(event_id, auth_events)
new_front -= results
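Both hunks in this file follow the same cache-or-fetch shape: look each event up in `_event_auth_cache`, batch the misses into one SQL query, and only cache complete per-event row sets. A hedged distillation of the pattern (`fetch_rows` is a hypothetical stand-in for the `txn.execute` round trip):

def cached_auth_lookup(cache, fetch_rows, event_ids):
    results = {}  # event_id -> list of (auth_event_id, depth)
    to_fetch = []
    for event_id in event_ids:
        rows = cache.get(event_id)
        if rows is None:
            to_fetch.append(event_id)
        else:
            results[event_id] = rows
    if to_fetch:
        # batch rows up by event ID before caching: caching a partial
        # set of rows for an event would poison later lookups
        to_cache = {}
        for event_id, auth_id, depth in fetch_rows(to_fetch):
            to_cache.setdefault(event_id, []).append((auth_id, depth))
        for event_id, rows in to_cache.items():
            cache[event_id] = rows
        results.update(to_cache)
    return results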
@@ -205,14 +247,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
break
# Fetch the auth events and their depths of the N last events we're
- # currently walking
+ # currently walking, either from cache or DB.
search, chunk = search[:-100], search[-100:]
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
- )
- txn.execute(base_sql + clause, args)
- for event_id, auth_event_id, auth_event_depth in txn:
+ found = [] # Results found # type: List[Tuple[str, str, int]]
+ to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ for _, event_id in chunk:
+ res = self._event_auth_cache.get(event_id)
+ if res is None:
+ to_fetch.append(event_id)
+ else:
+ found.extend((event_id, auth_id, depth) for auth_id, depth in res)
+
+ if to_fetch:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", to_fetch
+ )
+ txn.execute(base_sql + clause, args)
+
+ # We parse the results and add them to the `found` set and the
+ # cache (note we need to batch up the results by event ID before
+ # adding to the cache).
+ to_cache = {}
+ for event_id, auth_event_id, auth_event_depth in txn:
+ to_cache.setdefault(event_id, []).append(
+ (auth_event_id, auth_event_depth)
+ )
+ found.append((event_id, auth_event_id, auth_event_depth))
+
+ for event_id, auth_events in to_cache.items():
+ self._event_auth_cache.set(event_id, auth_events)
+
+ for event_id, auth_event_id, auth_event_depth in found:
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
sets = event_to_missing_sets.get(auth_event_id)
@@ -586,6 +652,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row["event_id"] for row in rows]
+ @wrap_as_background_process("delete_old_forward_extrem_cache")
+ async def _delete_old_forward_extrem_cache(self) -> None:
+ def _delete_old_forward_extrem_cache_txn(txn):
+ # Delete entries older than a month, while making sure we don't delete
+ # the only entries for a room.
+ sql = """
+ DELETE FROM stream_ordering_to_exterm
+ WHERE
+ room_id IN (
+ SELECT room_id
+ FROM stream_ordering_to_exterm
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
+ """
+ txn.execute(
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+ )
+
+ await self.db_pool.runInteraction(
+ "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+ )
+
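`wrap_as_background_process`, used here and throughout this patch, replaces the old pattern of a small wrapper method calling `run_as_background_process(...)`. A hedged sketch of the decorator's shape, inferred from the call sites it replaces (the real helper lives in synapse.metrics.background_process_metrics and may differ):

from functools import wraps

from synapse.metrics.background_process_metrics import run_as_background_process

def wrap_as_background_process(desc):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # looping_call wants a plain callable; this schedules the
            # wrapped coroutine as a named background process
            return run_as_background_process(desc, func, *args, **kwargs)
        return wrapper
    return decorator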
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
@@ -606,34 +694,6 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
- hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
- )
-
- def _delete_old_forward_extrem_cache(self):
- def _delete_old_forward_extrem_cache_txn(txn):
- # Delete entries older than a month, while making sure we don't delete
- # the only entries for a room.
- sql = """
- DELETE FROM stream_ordering_to_exterm
- WHERE
- room_id IN (
- SELECT room_id
- FROM stream_ordering_to_exterm
- WHERE stream_ordering > ?
- ) AND stream_ordering < ?
- """
- txn.execute(
- sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
- )
-
- return run_as_background_process(
- "delete_old_forward_extrem_cache",
- self.db_pool.runInteraction,
- "_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn,
- )
-
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..2e56dfaf31 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,15 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Dict, List, Optional, Tuple, Union
import attr
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.stream_ordering_month_ago = None
self.stream_ordering_day_ago = None
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- )
+ cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
cur.close()
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
+
self._rotate_delay = 3
self._rotate_count = 10000
+ self._doing_notif_rotation = False
+ if hs.config.run_background_tasks:
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
@@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Error removing push actions after event persistence failure"
)
- def _find_stream_orderings_for_times(self):
- return run_as_background_process(
- "event_push_action_stream_orderings",
- self.db_pool.runInteraction,
+ @wrap_as_background_process("event_push_action_stream_orderings")
+ async def _find_stream_orderings_for_times(self) -> None:
+ await self.db_pool.runInteraction(
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
- def _find_stream_orderings_for_times_txn(self, txn):
+ def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
logger.info("Searching for stream ordering 1 month ago")
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
@@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
return result[0] if result else None
-
-class EventPushActionsStore(EventPushActionsWorkerStore):
- EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self.db_pool.updates.register_background_index_update(
- self.EPA_HIGHLIGHT_INDEX,
- index_name="event_push_actions_u_highlight",
- table="event_push_actions",
- columns=["user_id", "stream_ordering"],
- )
-
- self.db_pool.updates.register_background_index_update(
- "event_push_actions_highlights_index",
- index_name="event_push_actions_highlights_index",
- table="event_push_actions",
- columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1",
- )
-
- self._doing_notif_rotation = False
- self._rotate_notif_loop = self._clock.looping_call(
- self._start_rotate_notifs, 30 * 60 * 1000
- )
-
- async def get_push_actions_for_user(
- self, user_id, before=None, limit=50, only_highlight=False
- ):
- def f(txn):
- before_clause = ""
- if before:
- before_clause = "AND epa.stream_ordering < ?"
- args = [user_id, before, limit]
- else:
- args = [user_id, limit]
-
- if only_highlight:
- if len(before_clause) > 0:
- before_clause += " "
- before_clause += "AND epa.highlight = 1"
-
- # NB. This assumes event_ids are globally unique since
- # it makes the query easier to index
- sql = (
- "SELECT epa.event_id, epa.room_id,"
- " epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
- " FROM event_push_actions epa, events e"
- " WHERE epa.event_id = e.event_id"
- " AND epa.user_id = ? %s"
- " AND epa.notif = 1"
- " ORDER BY epa.stream_ordering DESC"
- " LIMIT ?" % (before_clause,)
- )
- txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
-
- push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
- for pa in push_actions:
- pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- return push_actions
-
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
- def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
-
- Args:
- txn: The transcation
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id),
- )
-
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
- )
-
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
-
- def _start_rotate_notifs(self):
- return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+ @wrap_as_background_process("rotate_notifs")
async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
@@ -958,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1",
+ )
+
+ async def get_push_actions_for_user(
+ self, user_id, before=None, limit=50, only_highlight=False
+ ):
+ def f(txn):
+ before_clause = ""
+ if before:
+ before_clause = "AND epa.stream_ordering < ?"
+ args = [user_id, before, limit]
+ else:
+ args = [user_id, limit]
+
+ if only_highlight:
+ if len(before_clause) > 0:
+ before_clause += " "
+ before_clause += "AND epa.highlight = 1"
+
+ # NB. This assumes event_ids are globally unique since
+ # it makes the query easier to index
+ sql = (
+ "SELECT epa.event_id, epa.room_id,"
+ " epa.stream_ordering, epa.topological_ordering,"
+ " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+ " FROM event_push_actions epa, events e"
+ " WHERE epa.event_id = e.event_id"
+ " AND epa.user_id = ? %s"
+ " AND epa.notif = 1"
+ " ORDER BY epa.stream_ordering DESC"
+ " LIMIT ?" % (before_clause,)
+ )
+ txn.execute(sql, args)
+ return self.db_pool.cursor_to_dict(txn)
+
+ push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+ for pa in push_actions:
+ pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+ return push_actions
+
+ async def get_latest_push_action_stream_ordering(self):
+ def f(txn):
+ txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+ return txn.fetchone()
+
+ result = await self.db_pool.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
+ return result[0] or 0
+
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+ However, we keep a month's worth of highlighted notifications, so that
+ users can still get a list of recent highlights.
+
+ Args:
+ txn: The transaction
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+ # We need to join on the events table to get the received_ts for
+ # event_push_actions and sqlite won't let us use a join in a delete so
+ # we can't just delete where received_ts < x. Furthermore we can
+ # only identify event_push_actions by a tuple of room_id, event_id,
+ # so we can't use a subquery.
+ # Instead, we look up the stream ordering for the last event in that
+ # room received before the threshold time and delete event_push_actions
+ # in the room with a stream_ordering before that.
+ txn.execute(
+ "DELETE FROM event_push_actions "
+ " WHERE user_id = ? AND room_id = ? AND "
+ " stream_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ )
+
+ txn.execute(
+ """
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
+
def _action_has_highlight(actions):
for action in actions:
try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..90fb1a1f00 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -52,16 +52,6 @@ event_counter = Counter(
)
-def encode_json(json_object):
- """
- Encode a Python object as JSON and return it in a Unicode string.
- """
- out = frozendict_json_encoder.encode(json_object)
- if isinstance(out, bytes):
- out = out.decode("utf8")
- return out
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -341,6 +331,10 @@ class PersistEventsStore:
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+ # stream orderings should have been assigned by now
+ assert min_stream_order
+ assert max_stream_order
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -367,6 +361,8 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
+ self._persist_transaction_ids_txn(txn, events_and_contexts)
+
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
@@ -411,6 +407,35 @@ class PersistEventsStore:
# room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ def _persist_transaction_ids_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ):
+ """Persist the mapping from transaction IDs to event IDs (if defined).
+ """
+
+ to_insert = []
+ for event, _ in events_and_contexts:
+ token_id = getattr(event.internal_metadata, "token_id", None)
+ txn_id = getattr(event.internal_metadata, "txn_id", None)
+ if token_id and txn_id:
+ to_insert.append(
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "user_id": event.sender,
+ "token_id": token_id,
+ "txn_id": txn_id,
+ "inserted_ts": self._clock.time_msec(),
+ }
+ )
+
+ if to_insert:
+ self.db_pool.simple_insert_many_txn(
+ txn, table="event_txn_id", values=to_insert,
+ )
+
def _update_current_state_txn(
self,
txn: LoggingTransaction,
@@ -432,12 +457,12 @@ class PersistEventsStore:
# so that async background tasks get told what happened.
sql = """
INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, room_id, type, state_key, null, event_id
+ (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, room_id, type, state_key, null, event_id
FROM current_state_events
WHERE room_id = ?
"""
- txn.execute(sql, (stream_id, room_id))
+ txn.execute(sql, (stream_id, self._instance_name, room_id))
self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id},
@@ -458,8 +483,8 @@ class PersistEventsStore:
#
sql = """
INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
+ (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, ?, (
SELECT event_id FROM current_state_events
WHERE room_id = ? AND type = ? AND state_key = ?
)
@@ -469,6 +494,7 @@ class PersistEventsStore:
(
(
stream_id,
+ self._instance_name,
room_id,
etype,
state_key,
@@ -743,7 +769,7 @@ class PersistEventsStore:
logger.exception("")
raise
- metadata_json = encode_json(event.internal_metadata.get_dict())
+ metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -759,6 +785,7 @@ class PersistEventsStore:
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
+ "instance_name": self._instance_name,
},
)
@@ -797,10 +824,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
- "internal_metadata": encode_json(
+ "internal_metadata": json_encoder.encode(
event.internal_metadata.get_dict()
),
- "json": encode_json(event_dict(event)),
+ "json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts
@@ -1022,9 +1049,7 @@ class PersistEventsStore:
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.prefill(
- (cache_entry[0].event_id,), cache_entry
- )
+ self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
txn.call_after(prefill)
@@ -1241,6 +1266,10 @@ class PersistEventsStore:
)
def _store_retention_policy_for_room_txn(self, txn, event):
+ if not event.is_state():
+ logger.debug("Ignoring non-state m.room.retention event")
+ return
+
if hasattr(event, "content") and (
"min_lifetime" in event.content or "max_lifetime" in event.content
):
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 5e4af2eb51..97b6754846 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -92,6 +92,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
+ self.db_pool.updates.register_background_index_update(
+ "users_have_local_media",
+ index_name="users_have_local_media",
+ table="local_media_repository",
+ columns=["user_id", "created_ts"],
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f95679ebc4..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
import threading
@@ -32,9 +31,13 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
@@ -42,8 +45,9 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached
+from synapse.types import Collection, JsonDict, get_domain_from_id
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -74,6 +78,13 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
+ # Whether to use dedicated DB threads for event fetching. This is only used
+ # if there are multiple DB threads available. When used, this will lock
+ # the DB thread for periods of time (so unit tests want to disable this when they
+ # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+ # options controlling this.
+ USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
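Since the flag is a plain class attribute, test code can override it per store class. A hedged sketch (subclass name hypothetical):

class SynchronousEventsWorkerStore(EventsWorkerStore):
    # Fetch events on the calling DB thread so tests that run DB
    # transactions on the main thread don't block behind the fetcher.
    USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False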
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -130,11 +141,16 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
- self._get_event_cache = Cache(
- "*getEvent*",
+ if hs.config.run_background_tasks:
+ # We periodically clean out old transaction ID mappings
+ self._clock.looping_call(
+ self._cleanup_old_transaction_ids, 5 * 60 * 1000,
+ )
+
+ self._get_event_cache = LruCache(
+ cache_name="*getEvent*",
keylen=3,
- max_entries=hs.config.caches.event_cache_size,
- apply_cache_factor_from_config=False,
+ max_size=hs.config.caches.event_cache_size,
)
self._event_fetch_lock = threading.Condition()
@@ -510,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
return event_map
+ async def get_stripped_room_state_from_event_context(
+ self,
+ context: EventContext,
+ state_types_to_include: List[EventTypes],
+ membership_user_id: Optional[str] = None,
+ ) -> List[JsonDict]:
+ """
+ Retrieve the stripped state from a room, given an event context to retrieve state
+ from as well as the state types to include. Optionally, include the membership
+ events from a specific user.
+
+ "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+ are included from each state event.
+
+ Args:
+ context: The event context to retrieve state of the room from.
+ state_types_to_include: The type of state events to include.
+ membership_user_id: An optional user ID to include the stripped membership state
+ events of. This is useful when generating the stripped state of a room for
+ invites. We want to send membership events of the inviter, so that the
+ invitee can display the inviter's profile information if the room lacks any.
+
+ Returns:
+ A list of dictionaries, each representing a stripped state event from the room.
+ """
+ current_state_ids = await context.get_current_state_ids()
+
+ # We know this event is not an outlier, so this must be
+ # non-None.
+ assert current_state_ids is not None
+
+ # The state to include
+ state_to_include_ids = [
+ e_id
+ for k, e_id in current_state_ids.items()
+ if k[0] in state_types_to_include
+ or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+ ]
+
+ state_to_include = await self.get_events(state_to_include_ids)
+
+ return [
+ {
+ "type": e.type,
+ "state_key": e.state_key,
+ "content": e.content,
+ "sender": e.sender,
+ }
+ for e in state_to_include.values()
+ ]
+
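For illustration, the stripped state returned for a room with a name and the inviter's membership might look like this (values hypothetical):

[
    {
        "type": "m.room.name",
        "state_key": "",
        "content": {"name": "Example room"},
        "sender": "@alice:example.org",
    },
    {
        "type": "m.room.member",
        "state_key": "@alice:example.org",
        "content": {"membership": "join"},
        "sender": "@alice:example.org",
    },
]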
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
@@ -522,7 +589,11 @@ class EventsWorkerStore(SQLBaseStore):
if not event_list:
single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
self._event_fetch_ongoing -= 1
return
else:
@@ -712,6 +783,7 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
+ original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
event_map[event_id] = original_ev
@@ -728,7 +800,7 @@ class EventsWorkerStore(SQLBaseStore):
event=original_ev, redacted_event=redacted_event
)
- self._get_event_cache.prefill((event_id,), cache_entry)
+ self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
return result_map
@@ -779,6 +851,8 @@ class EventsWorkerStore(SQLBaseStore):
* event_id (str)
+ * stream_ordering (int): stream ordering for this event
+
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
@@ -811,13 +885,15 @@ class EventsWorkerStore(SQLBaseStore):
sql = """\
SELECT
e.event_id,
- e.internal_metadata,
- e.json,
- e.format_version,
+ e.stream_ordering,
+ ej.internal_metadata,
+ ej.json,
+ ej.format_version,
r.room_version,
rej.reason
- FROM event_json as e
- LEFT JOIN rooms r USING (room_id)
+ FROM events AS e
+ JOIN event_json AS ej USING (event_id)
+ LEFT JOIN rooms r ON r.room_id = e.room_id
LEFT JOIN rejections as rej USING (event_id)
WHERE """
@@ -831,11 +907,12 @@ class EventsWorkerStore(SQLBaseStore):
event_id = row[0]
event_dict[event_id] = {
"event_id": event_id,
- "internal_metadata": row[1],
- "json": row[2],
- "format_version": row[3],
- "room_version_id": row[4],
- "rejected_reason": row[5],
+ "stream_ordering": row[1],
+ "internal_metadata": row[2],
+ "json": row[3],
+ "format_version": row[4],
+ "room_version_id": row[5],
+ "rejected_reason": row[6],
"redactions": [],
}
@@ -1017,16 +1094,12 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
- self, last_id: int, current_id: int, limit: int
+ self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
"""Returns new events, for the Events replication stream
@@ -1044,16 +1117,19 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
+ " LEFT JOIN room_memberships USING (event_id)"
+ " LEFT JOIN rejections USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " AND instance_name = ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
return txn.fetchall()
return await self.db_pool.runInteraction(
@@ -1061,7 +1137,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_ex_outlier_stream_rows(
- self, last_id: int, current_id: int
+ self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
@@ -1078,18 +1154,21 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
+ " INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
+ " LEFT JOIN room_memberships USING (event_id)"
+ " LEFT JOIN rejections USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
+ " AND out.instance_name = ?"
" ORDER BY event_stream_ordering ASC"
)
- txn.execute(sql, (last_id, current_id))
+ txn.execute(sql, (last_id, current_id, instance_name))
return txn.fetchall()
return await self.db_pool.runInteraction(
@@ -1102,6 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
+ NOTE: The IDs given here are from replication, and so should be
+ *positive*.
+
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
@@ -1132,10 +1214,11 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " AND instance_name = ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
- txn.execute(sql, (-last_id, -current_id, limit))
+ txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates = [(row[0], row[1:]) for row in txn]
limited = False
@@ -1149,15 +1232,16 @@ class EventsWorkerStore(SQLBaseStore):
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
+ " INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
+ " AND out.instance_name = ?"
" ORDER BY event_stream_ordering DESC"
)
- txn.execute(sql, (-last_id, -upper_bound))
+ txn.execute(sql, (-last_id, -upper_bound, instance_name))
new_event_updates.extend((row[0], row[1:]) for row in txn)
if len(new_event_updates) >= limit:
@@ -1171,7 +1255,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_all_updated_current_state_deltas(
- self, from_token: int, to_token: int, target_row_count: int
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream
@@ -1197,9 +1281,10 @@ class EventsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, target_row_count))
+ txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
return txn.fetchall()
def get_deltas_for_stream_id_txn(txn, stream_id):
@@ -1287,3 +1372,77 @@ class EventsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
+
+ async def get_event_id_from_transaction_id(
+ self, room_id: str, user_id: str, token_id: int, txn_id: str
+ ) -> Optional[str]:
+ """Look up if we have already persisted an event for the transaction ID,
+ returning the event ID if so.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="event_txn_id",
+ keyvalues={
+ "room_id": room_id,
+ "user_id": user_id,
+ "token_id": token_id,
+ "txn_id": txn_id,
+ },
+ retcol="event_id",
+ allow_none=True,
+ desc="get_event_id_from_transaction_id",
+ )
+
+ async def get_already_persisted_events(
+ self, events: Iterable[EventBase]
+ ) -> Dict[str, str]:
+ """Look up if we have already persisted an event for the transaction ID,
+ returning a mapping from event ID in the given list to the event ID of
+ an existing event.
+
+ Also checks if there are duplicates in the given events; if there are,
+ duplicates are mapped to the *first* event.
+ """
+
+ mapping = {}
+ txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
+
+ for event in events:
+ token_id = getattr(event.internal_metadata, "token_id", None)
+ txn_id = getattr(event.internal_metadata, "txn_id", None)
+
+ if token_id and txn_id:
+ # Check if this is a duplicate of an event in the given events.
+ existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
+ if existing:
+ mapping[event.event_id] = existing
+ continue
+
+ # Check if this is a duplicate of an event we've already
+ # persisted.
+ existing = await self.get_event_id_from_transaction_id(
+ event.room_id, event.sender, token_id, txn_id
+ )
+ if existing:
+ mapping[event.event_id] = existing
+ txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
+ else:
+ txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
+
+ return mapping
+
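Together with `_persist_transaction_ids_txn` and the `event_txn_id` table above, this gives event sending at-most-once semantics across client retries. A hedged usage sketch (`store` is an `EventsWorkerStore`; the helper name is hypothetical):

async def find_existing_event_id(store, event):
    # If a client retried /send with the same transaction ID, reuse the
    # event ID we persisted the first time instead of minting a duplicate.
    token_id = getattr(event.internal_metadata, "token_id", None)
    txn_id = getattr(event.internal_metadata, "txn_id", None)
    if token_id and txn_id:
        return await store.get_event_id_from_transaction_id(
            event.room_id, event.sender, token_id, txn_id
        )
    return None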
+ @wrap_as_background_process("_cleanup_old_transaction_ids")
+ async def _cleanup_old_transaction_ids(self):
+ """Cleans out transaction id mappings older than 24hrs.
+ """
+
+ def _cleanup_old_transaction_ids_txn(txn):
+ sql = """
+ DELETE FROM event_txn_id
+ WHERE inserted_ts < ?
+ """
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ txn.execute(sql, (one_day_ago,))
+
+ return await self.db_pool.runInteraction(
+ "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+ )
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ad43bb05ab..f8f4bb9b3f 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- await self.db_pool.runInteraction(
- "store_server_verify_keys",
- self.db_pool.simple_upsert_many_txn,
+ await self.db_pool.simple_upsert_many(
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
+ desc="store_server_verify_keys",
)
invalidate = self._get_server_verify_key.invalidate
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cc538c5c10..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -93,6 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
+ self.server_name = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
@@ -115,6 +116,109 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media",
)
+ async def get_local_media_by_user_paginate(
+ self, start: int, limit: int, user_id: str
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Get a paginated list of metadata for a local piece of media
+ which an user_id has uploaded
+
+ Args:
+ start: offset in the list
+ limit: maximum amount of media_ids to retrieve
+ user_id: fully-qualified user id
+ Returns:
+ A paginated list of the metadata of the user's media,
+ plus the total count of all the user's media
+ """
+
+ def get_local_media_by_user_paginate_txn(txn):
+
+ args = [user_id]
+ sql = """
+ SELECT COUNT(*) as total_media
+ FROM local_media_repository
+ WHERE user_id = ?
+ """
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ "media_id",
+ "media_type",
+ "media_length",
+ "upload_name",
+ "created_ts",
+ "last_access_ts",
+ "quarantined_by",
+ "safe_from_quarantine"
+ FROM local_media_repository
+ WHERE user_id = ?
+ ORDER BY created_ts DESC, media_id DESC
+ LIMIT ? OFFSET ?
+ """
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ media = self.db_pool.cursor_to_dict(txn)
+ return media, count
+
+ return await self.db_pool.runInteraction(
+ "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
+ )
+
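A hedged sketch of paging through a user's media with the API above (page size arbitrary):

async def iter_user_media(store, user_id, page_size=50):
    start = 0
    while True:
        media, total = await store.get_local_media_by_user_paginate(
            start, page_size, user_id
        )
        for item in media:
            yield item  # one dict per row, keys as in the SELECT above
        start += page_size
        if start >= total:
            break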
+ async def get_local_media_before(
+ self, before_ts: int, size_gt: int, keep_profiles: bool,
+ ) -> Optional[List[str]]:
+
+ # to find files that have never been accessed (last_access_ts IS NULL)
+ # compare with `created_ts`
+ sql = """
+ SELECT media_id
+ FROM local_media_repository AS lmr
+ WHERE
+ ( last_access_ts < ?
+ OR ( created_ts < ? AND last_access_ts IS NULL ) )
+ AND media_length > ?
+ """
+
+ if keep_profiles:
+ sql_keep = """
+ AND (
+ NOT EXISTS
+ (SELECT 1
+ FROM profiles
+ WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
+ AND NOT EXISTS
+ (SELECT 1
+ FROM groups
+ WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
+ AND NOT EXISTS
+ (SELECT 1
+ FROM room_memberships
+ WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
+ AND NOT EXISTS
+ (SELECT 1
+ FROM user_directory
+ WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id)
+ AND NOT EXISTS
+ (SELECT 1
+ FROM room_stats_state
+ WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id)
+ )
+ """.format(
+ media_prefix="mxc://%s/" % (self.server_name,),
+ )
+ sql += sql_keep
+
+ def _get_local_media_before_txn(txn):
+ txn.execute(sql, (before_ts, before_ts, size_gt))
+ return [row[0] for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_local_media_before", _get_local_media_before_txn
+ )
+
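A hedged sketch of how a purge job might drive this query (`delete_media` is a hypothetical helper; `keep_profiles=True` preserves anything referenced as an avatar, per the NOT EXISTS clauses above):

async def purge_old_local_media(store, delete_media, before_ts, size_gt=0):
    media_ids = await store.get_local_media_before(
        before_ts, size_gt, keep_profiles=True
    )
    for media_id in media_ids or []:
        await delete_media(media_id)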
async def store_local_media(
self,
media_id,
@@ -348,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
+ async def get_remote_media_thumbnail(
+ self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+ ) -> Optional[Dict[str, Any]]:
+ """Fetch the thumbnail info of given width, height and type.
+ """
+
+ return await self.db_pool.simple_select_one(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": origin,
+ "media_id": media_id,
+ "thumbnail_width": t_width,
+ "thumbnail_height": t_height,
+ "thumbnail_type": t_type,
+ },
+ retcols=(
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ "filesystem_id",
+ ),
+ allow_none=True,
+ desc="get_remote_media_thumbnail",
+ )
+
async def store_remote_media_thumbnail(
self,
origin,
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92099f95ce..ab18cc4d79 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,15 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import calendar
+import logging
+import time
+from typing import Dict
from synapse.metrics import GaugeBucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+logger = logging.getLogger(__name__)
+
# Collect metrics on the number of forward extremities that exist.
_extremities_collecter = GaugeBucketCollector(
"synapse_forward_extremities",
@@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
- def read_forward_extremities():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "read_forward_extremities", self._read_forward_extremities
- )
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
- hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+ @wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
def fetch(txn):
txn.execute(
@@ -137,3 +141,196 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+ async def count_daily_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 24 hours.
+ """
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return await self.db_pool.runInteraction(
+ "count_daily_users", self._count_users, yesterday
+ )
+
+ async def count_monthly_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phonehome metrics only and is different
+ from the mau figure in synapse.storage.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return await self.db_pool.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
+
+ def _count_users(self, txn, time_from):
+ """
+ Returns number of users seen in the past time_from period
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ (count,) = txn.fetchone()
+ return count
+
+ async def count_r30_users(self) -> Dict[str, int]:
+ """
+ Counts the number of 30 day retained users, defined as:
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
+ """
+
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == "unknown":
+ pass
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ (count,) = txn.fetchone()
+ results["all"] = count
+
+ return results
+
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = time.gmtime()
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+ return today_start * 1000
+
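For example, a hedged walk-through of the helper above (timestamps illustrative):

import calendar
import time

now = time.gmtime(1602682662)  # 2020-10-14 13:37:42 UTC
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
assert today_start == 1602633600             # 2020-10-14 00:00:00 UTC
assert today_start * 1000 == 1602633600000   # milliseconds, as returned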
+ @wrap_as_background_process("generate_user_daily_visits")
+ async def generate_user_daily_visits(self) -> None:
+ """
+ Generates daily visit data for use in cohort/retention analysis
+ """
+
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self._clock.time_msec()
+
+ # A note on user_agent. Technically a given device can have multiple
+ # user agents, so we need to decide which one to pick. We could have
+ # handled this in a number of ways, but given that we don't care
+ # _that_ much we have gone for MAX(). For more details of the other
+ # options considered see
+ # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent)
+ SELECT u.user_id, u.device_id, ?, MAX(u.user_agent)
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
+ self._last_user_visit_update = today_start
+
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather than just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ await self.db_pool.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
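Concretely, the rollover branch above behaves like this (a worked sketch in milliseconds; values illustrative):

DAY_MS = 24 * 60 * 60 * 1000
today_start = 1602633600000                 # 2020-10-14 00:00:00 UTC
last_update = today_start - 5 * 60 * 1000   # loop last ran at 23:55 yesterday

if today_start > last_update:
    # first run after midnight: visits seen in (last_update, today_start]
    # are recorded against *yesterday's* bucket first...
    yesterday_start = today_start - DAY_MS
# ...then visits in (last_update, now] are recorded against today, and
# last_update is advanced to now to keep the join small.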
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e93aad33cd..d788dc0fc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,6 +15,7 @@
import logging
from typing import Dict, List
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
@@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._max_mau_value = hs.config.max_mau_value
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
desc="user_last_seen_monthly_active",
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_stats_only = hs.config.mau_stats_only
- self._max_mau_value = hs.config.max_mau_value
-
- # Do not add more reserved users than the total allowable number
- # cur = LoggingTransaction(
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
- def _initialise_reserved_users(self, txn, threepids):
- """Ensures that reserved threepids are accounted for in the MAU table, should
- be called on start up.
-
- Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
- """
-
- # XXX what is this function trying to achieve? It upserts into
- # monthly_active_users for each *registered* reserved mau user, but why?
- #
- # - shouldn't there already be an entry for each reserved user (at least
- # if they have been active recently)?
- #
- # - if it's important that the timestamp is kept up to date, why do we only
- # run this at startup?
-
- for tp in threepids:
- user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
- if user_id:
- is_support = self.is_support_user_txn(txn, user_id)
- if not is_support:
- # We do this manually here to avoid hitting #6791
- self.db_pool.simple_upsert_txn(
- txn,
- table="monthly_active_users",
- keyvalues={"user_id": user_id},
- values={"timestamp": int(self._clock.time_msec())},
- )
- else:
- logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
+ @wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
@@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._mau_stats_only = hs.config.mau_stats_only
+
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
+ def _initialise_reserved_users(self, txn, threepids):
+ """Ensures that reserved threepids are accounted for in the MAU table, should
+ be called on start up.
+
+ Args:
+ txn (cursor):
+ threepids (list[dict]): List of threepid dicts to reserve
+ """
+
+ # XXX what is this function trying to achieve? It upserts into
+ # monthly_active_users for each *registered* reserved mau user, but why?
+ #
+ # - shouldn't there already be an entry for each reserved user (at least
+ # if they have been active recently)?
+ #
+ # - if it's important that the timestamp is kept up to date, why do we only
+ # run this at startup?
+
+ for tp in threepids:
+ user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+ if user_id:
+ is_support = self.is_support_user_txn(txn, user_id)
+ if not is_support:
+ # We do this manually here to avoid hitting #6791
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
+ )
+ else:
+ logger.warning("mau limit reserved threepid %s not found in db" % tp)
+
async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..0e25ca3d7a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
- async def get_profile_displayname(self, user_localpart: str) -> str:
+ async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
@@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
- async def get_profile_avatar_url(self, user_localpart: str) -> str:
+ async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
- self, user_localpart: str, new_displayname: str
+ self, user_localpart: str, new_displayname: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",
@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="set_profile_avatar_url",
)
-
-class ProfileStore(ProfileWorkerStore):
- async def add_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
- ) -> None:
- """Ensure we are caching the remote user's profiles.
-
- This should only be called when `is_subscribed_remote_profile_for_user`
- would return true for the user.
- """
- await self.db_pool.simple_upsert(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- values={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="add_remote_profile_cache",
- )
-
async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> int:
@@ -138,9 +117,34 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
+ async def is_subscribed_remote_profile_for_user(self, user_id):
+ """Check whether we are interested in a remote user's profile.
+ """
+ res = await self.db_pool.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ return True
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ return True
+
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
- ) -> Dict[str, str]:
+ ) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore):
_get_remote_profile_cache_entries_that_expire_txn,
)
- async def is_subscribed_remote_profile_for_user(self, user_id):
- """Check whether we are interested in a remote user's profile.
- """
- res = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
- if res:
- return True
+class ProfileStore(ProfileWorkerStore):
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
+ """Ensure we are caching the remote user's profiles.
- res = await self.db_pool.simple_select_one_onecol(
- table="group_invites",
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ await self.db_pool.simple_upsert(
+ table="remote_profile_cache",
keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
)
-
- if res:
- return True
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
+ "event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
- "event_json",
"event_push_actions",
"event_search",
"events",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..7997242d90 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
- user_has_pusher = self.get_if_user_has_pusher.cache.get(
+ user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
(user_id,), None, update_metrics=False
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c79ddff680..1e7949a323 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -23,8 +23,8 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util import json_encoder
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -274,6 +274,65 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
+    @cached(num_args=2)
+ async def get_linearized_receipts_for_all_rooms(
+ self, to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, JsonDict]:
+ """Get receipts for all rooms between two stream_ids, up
+ to a limit of the latest 100 read receipts.
+
+ Args:
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+            A dictionary mapping each room ID to an m.receipt event containing
+            that room's receipts, batched by event and receipt type.
+ """
+
+ def f(txn):
+ if from_key:
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
+ """
+ txn.execute(sql, [from_key, to_key])
+ else:
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
+ """
+
+ txn.execute(sql, [to_key])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ txn_results = await self.db_pool.runInteraction(
+ "get_linearized_receipts_for_all_rooms", f
+ )
+
+ results = {}
+ for row in txn_results:
+ # We want a single event per room, since we want to batch the
+ # receipts by room, event and type.
+ room_event = results.setdefault(
+ row["room_id"],
+ {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ )
+
+ # The content is of the form:
+ # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
+ event_entry = room_event["content"].setdefault(row["event_id"], {})
+ receipt_type = event_entry.setdefault(row["receipt_type"], {})
+
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
+
+ return results
+
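A short sketch, with made-up IDs, of the nested structure the loop above builds, illustrating the content shape noted in the inline comment:

    # Two rows for the same room fold into one m.receipt event, keyed by
    # event ID, then receipt type, then user ID (all IDs hypothetical).
    results = {
        "!room:example.com": {
            "type": "m.receipt",
            "room_id": "!room:example.com",
            "content": {
                "$event1:example.com": {
                    "m.read": {
                        "@alice:example.com": {"ts": 1604000000000},
                        "@bob:example.com": {"ts": 1604000001000},
                    }
                }
            },
        }
    }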
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:
@@ -358,18 +417,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
if receipt_type != "m.read":
return
- # Returns either an ObservableDeferred or the raw result
- res = self.get_users_with_read_receipts_in_room.cache.get(
+ res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
room_id, None, update_metrics=False
)
- # first handle the ObservableDeferred case
- if isinstance(res, ObservableDeferred):
- if res.has_called():
- res = res.get_result()
- else:
- res = None
-
if res and user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
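The change above relies on `get_immediate` returning only results that have already been computed, never a pending deferred. A toy model of that contract (illustrative only, not the real cache implementation):

    class ToyCache:
        """Illustrative sketch: get_immediate() never returns a Deferred."""

        def __init__(self):
            self._completed = {}  # key -> fully-computed value

        def get_immediate(self, key, default, update_metrics=True):
            # Values still being computed are simply not present here, so
            # the caller gets `default` instead of an ObservableDeferred.
            return self._completed.get(key, default)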
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..fedb8a6c26 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,32 +14,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import DatabasePool
-from synapse.storage.types import Cursor
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
-class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+ """Result of looking up an access token.
+
+ Attributes:
+ user_id: The user that this token authenticates as
+        is_guest: Whether the user is a guest.
+        shadow_banned: Whether the user has been shadow-banned.
+ token_id: The ID of the access token looked up
+ device_id: The device associated with the token, if any.
+ valid_until_ms: The timestamp the token expires, if any.
+ token_owner: The "owner" of the token. This is either the same as the
+ user, or a server admin who is logged in as the user.
+ """
+
+ user_id = attr.ib(type=str)
+ is_guest = attr.ib(type=bool, default=False)
+ shadow_banned = attr.ib(type=bool, default=False)
+ token_id = attr.ib(type=Optional[int], default=None)
+ device_id = attr.ib(type=Optional[str], default=None)
+ valid_until_ms = attr.ib(type=Optional[int], default=None)
+ token_owner = attr.ib(type=str)
+
+ # Make the token owner default to the user ID, which is the common case.
+ @token_owner.default
+ def _default_token_owner(self):
+ return self.user_id
+
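The `@token_owner.default` decorator is attrs' mechanism for defaults that depend on other attributes; a standalone sketch of the behaviour:

    import attr

    @attr.s(frozen=True, slots=True)
    class Example:
        user_id = attr.ib(type=str)
        token_owner = attr.ib(type=str)

        @token_owner.default
        def _default_token_owner(self):
            # attrs calls this with the partially-initialised instance, so
            # earlier attributes such as user_id are already set.
            return self.user_id

    assert Example(user_id="@alice:hs").token_owner == "@alice:hs"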
+
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
- self.clock = hs.get_clock()
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
@@ -48,6 +82,18 @@ class RegistrationWorkerStore(SQLBaseStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
+ self._account_validity = hs.config.account_validity
+ if hs.config.run_background_tasks and self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+ )
+
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
@@ -81,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore):
if not info:
return False
- now = self.clock.time_msec()
+ now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@cached()
- async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+ async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
- None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise a `TokenLookupResult`
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@@ -225,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_renewal_token_for_user",
)
- async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
+ async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- A list of dictionaries mapping user ID to expiration time (in milliseconds).
+ A list of dictionaries, each with a user ID and expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@@ -246,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(),
+ self._clock.time_msec(),
self.config.account_validity.renew_at,
)
@@ -316,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
- def _query_for_auth(self, txn, token):
- sql = (
- "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
- " access_tokens.device_id, access_tokens.valid_until_ms"
- " FROM users"
- " INNER JOIN access_tokens on users.name = access_tokens.user_id"
- " WHERE token = ?"
- )
+ def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+ sql = """
+ SELECT users.name as user_id,
+ users.is_guest,
+ users.shadow_banned,
+ access_tokens.id as token_id,
+ access_tokens.device_id,
+ access_tokens.valid_until_ms,
+ access_tokens.user_id as token_owner
+ FROM users
+ INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
+ WHERE token = ?
+ """
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
- return rows[0]
+ return TokenLookupResult(**rows[0])
return None
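The `COALESCE(puppets_user_id, access_tokens.user_id)` join means a token with `puppets_user_id` set authenticates as the puppeted user, while `token_owner` still records the token's real owner. A sketch of the two cases (values hypothetical):

    # Ordinary token: puppets_user_id is NULL, so the join falls back to
    # the token's own user_id, and token_owner == user_id.
    TokenLookupResult(user_id="@alice:hs", token_id=1, token_owner="@alice:hs")

    # Puppet token: an admin's token with puppets_user_id = "@alice:hs"
    # authenticates as @alice:hs, but token_owner is still the admin.
    TokenLookupResult(user_id="@alice:hs", token_id=2, token_owner="@admin:hs")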
@@ -778,12 +827,111 @@ class RegistrationWorkerStore(SQLBaseStore):
"delete_threepid_session", delete_threepid_session_txn
)
+ @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+ async def cull_expired_threepid_validation_tokens(self) -> None:
+ """Remove threepid validation tokens with expiry dates that have passed"""
+
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self._clock.time_msec(),
+ )
+
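This mirrors the pattern used throughout this diff: rather than scheduling `run_as_background_process` by hand, the method itself is wrapped so every call runs, and is reported, as a background process. A minimal sketch of the shape (the job name here is hypothetical):

    from synapse.metrics.background_process_metrics import wrap_as_background_process

    class ExampleStore:
        @wrap_as_background_process("prune_old_rows")  # hypothetical job name
        async def prune_old_rows(self):
            # Runs with its own logcontext and shows up in background-process
            # metrics, whether invoked directly or from a looping_call.
            ...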
+ @wrap_as_background_process("account_validity_set_expiration_dates")
+ async def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db_pool.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ await self.db_pool.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d, now + period] range, where d
+                is a delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
+
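A worked sketch of the `use_delta` arithmetic above, assuming a hypothetical 30-day validity period:

    import random
    import time

    period_ms = 30 * 24 * 60 * 60 * 1000   # hypothetical validity period
    max_delta_ms = period_ms // 10         # startup_job_max_delta: 10% of period

    now_ms = int(time.time() * 1000)
    expiration_ts = now_ms + period_ms

    # With use_delta=True the date is smeared backwards by up to 10%, so
    # accounts initialised in one batch don't all expire at the same instant.
    expiration_ts = random.randrange(expiration_ts - max_delta_ms, expiration_ts)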
+ async def get_user_pending_deactivation(self) -> Optional[str]:
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
+
+ async def del_user_pending_deactivation(self, user_id: str) -> None:
+ """
+        Removes the given user from the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ await self.db_pool.simple_delete(
+ "users_pending_deactivation",
+ keyvalues={"user_id": user_id},
+ desc="del_user_pending_deactivation",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self.clock = hs.get_clock()
+ self._clock = hs.get_clock()
self.config = hs.config
self.db_pool.updates.register_background_index_update(
@@ -906,32 +1054,55 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return 1
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
+ """Set the `deactivated` property for the provided user to the provided value.
-class RegistrationStore(RegistrationBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
+ Args:
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
+ """
- self._account_validity = hs.config.account_validity
- self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+ await self.db_pool.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
+ def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
+ )
+ txn.call_after(self.is_guest.invalidate, (user_id,))
+
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
async def add_access_token_to_user(
self,
@@ -939,7 +1110,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
- ) -> None:
+ puppets_user_id: Optional[str] = None,
+ ) -> int:
"""Adds an access token for the given user.
Args:
@@ -949,6 +1121,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
+ Returns:
+ The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
@@ -960,10 +1134,43 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
+ "puppets_user_id": puppets_user_id,
},
desc="add_access_token_to_user",
)
+ return next_id
+
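A usage sketch of the extended signature (names and values hypothetical): creating a token that authenticates as another user while recording its real owner:

    # An admin token puppeting @alice:hs; the returned token ID can be used
    # to tie follow-up rows (e.g. event_txn_id) back to this token.
    token_id = await store.add_access_token_to_user(
        user_id="@admin:hs",
        token="syt_hypothetical_token",
        device_id=None,
        valid_until_ms=None,
        puppets_user_id="@alice:hs",
    )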
+ def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn, "access_tokens", {"token": token}, "device_id"
+ )
+
+ self.db_pool.simple_update_txn(
+ txn, "access_tokens", {"token": token}, {"device_id": device_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+ return old_device_id
+
+ async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+ """Sets the device ID associated with an access token.
+
+ Args:
+ token: The access token to modify.
+ device_id: The new device ID.
+ Returns:
+ The old device ID associated with the access token.
+ """
+
+ return await self.db_pool.runInteraction(
+ "set_device_for_access_token",
+ self._set_device_for_access_token_txn,
+ token,
+ device_id,
+ )
+
async def register_user(
self,
user_id: str,
@@ -1014,19 +1221,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def _register_user(
self,
txn,
- user_id,
- password_hash,
- was_guest,
- make_guest,
- appservice_id,
- create_profile_with_displayname,
- admin,
- user_type,
- shadow_banned,
+ user_id: str,
+ password_hash: Optional[str],
+ was_guest: bool,
+ make_guest: bool,
+ appservice_id: Optional[str],
+ create_profile_with_displayname: Optional[str],
+ admin: bool,
+ user_type: Optional[str],
+ shadow_banned: bool,
):
user_id_obj = UserID.from_string(user_id)
- now = int(self.clock.time())
+ now = int(self._clock.time())
try:
if was_guest:
@@ -1121,7 +1328,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+ async def user_set_password_hash(
+ self, user_id: str, password_hash: Optional[str]
+ ) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1248,18 +1457,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1271,32 +1468,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_user_pending_deactivation",
)
- async def del_user_pending_deactivation(self, user_id: str) -> None:
- """
- Removes the given user to the table of users who need to be parted from all the
- rooms they're in, effectively marking that user as fully deactivated.
- """
- # XXX: This should be simple_delete_one but we failed to put a unique index on
- # the table, so somehow duplicate entries have ended up in it.
- await self.db_pool.simple_delete(
- "users_pending_deactivation",
- keyvalues={"user_id": user_id},
- desc="del_user_pending_deactivation",
- )
-
- async def get_user_pending_deactivation(self) -> Optional[str]:
- """
- Gets one user from the table of users waiting to be parted from all the rooms
- they're in.
- """
- return await self.db_pool.simple_select_one_onecol(
- "users_pending_deactivation",
- keyvalues={},
- retcol="user_id",
- allow_none=True,
- desc="get_users_pending_deactivation",
- )
-
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
@@ -1379,7 +1550,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
- updatevalues={"validated_at": self.clock.time_msec()},
+ updatevalues={"validated_at": self._clock.time_msec()},
)
return next_link
@@ -1447,106 +1618,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
- )
-
- async def set_user_deactivated_status(
- self, user_id: str, deactivated: bool
- ) -> None:
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id: The ID of the user to set the status for.
- deactivated: The value to set for `deactivated`.
- """
-
- await self.db_pool.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id,
- deactivated,
- )
-
- def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"deactivated": 1 if deactivated else 0},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,)
- )
- txn.call_after(self.is_guest.invalidate, (user_id,))
-
- async def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- await self.db_pool.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self.db_pool.simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id},
- values={"expiration_ts_ms": expiration_ts, "email_sent": False},
- )
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3c7630857f..6b89db15c9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
"count_public_rooms", _count_public_rooms_txn
)
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return await self.db_pool.runInteraction("get_rooms", f)
+
async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
@@ -857,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
+ async def get_rooms_for_retention_period_in_range(
+ self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+ ) -> Dict[str, dict]:
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+            min_ms: Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms: Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null: Whether to include rooms whose retention policy is NULL
+                in the returned set.
+
+ Returns:
+            The rooms within this range, along with their retention policy.
+            Each key is a room ID, mapping to a dict describing that room's
+            retention policy; the keys of this nested dict are
+            "min_lifetime" (int|None) and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+            if range_conditions:
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.db_pool.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.db_pool.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ return await self.db_pool.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
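A usage sketch (values hypothetical): fetching rooms whose policy allows messages to live for at most a day, including rooms with no policy:

    ONE_DAY_MS = 24 * 60 * 60 * 1000

    rooms = await store.get_rooms_for_retention_period_in_range(
        min_ms=None, max_ms=ONE_DAY_MS, include_null=True
    )
    # e.g. {"!a:hs": {"min_lifetime": None, "max_lifetime": 86400000},
    #       "!b:hs": {"min_lifetime": None, "max_lifetime": None}}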
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1145,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+ async def maybe_store_room_on_outlier_membership(
+ self, room_id: str, room_version: RoomVersion
+ ):
"""
- When we receive an invite over federation, store the version of the room if we
- don't already know the room version.
+ When we receive an invite or any other event over federation that may relate to a room
+ we are not in, store the version of the room if we don't already know the room version.
"""
await self.db_pool.simple_upsert(
- desc="maybe_store_room_on_invite",
+ desc="maybe_store_room_on_outlier_membership",
table="rooms",
keyvalues={"room_id": room_id},
values={},
@@ -1292,18 +1389,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- async def get_room_count(self) -> int:
- """Retrieve the total number of rooms.
- """
-
- def f(txn):
- sql = "SELECT count(*) FROM rooms"
- txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
-
- return await self.db_pool.runInteraction("get_rooms", f)
-
async def add_event_report(
self,
room_id: str,
@@ -1328,6 +1413,65 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
desc="add_event_report",
)
+ async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+ """Retrieve an event report
+
+ Args:
+            report_id: ID of the event report in the database
+        Returns:
+            A dict of information from the event report, or None if no
+            report exists with the given ID.
+ """
+
+ def _get_event_report_txn(txn, report_id):
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.content,
+ events.sender,
+ room_stats_state.canonical_alias,
+ room_stats_state.name,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ JOIN room_stats_state
+ ON room_stats_state.room_id = er.room_id
+ WHERE er.id = ?
+ """
+
+ txn.execute(sql, [report_id])
+ row = txn.fetchone()
+
+ if not row:
+ return None
+
+ event_report = {
+ "id": row[0],
+ "received_ts": row[1],
+ "room_id": row[2],
+ "event_id": row[3],
+ "user_id": row[4],
+ "score": db_to_json(row[5]).get("score"),
+ "reason": db_to_json(row[5]).get("reason"),
+ "sender": row[6],
+ "canonical_alias": row[7],
+ "name": row[8],
+ "event_json": db_to_json(row[9]),
+ }
+
+ return event_report
+
+ return await self.db_pool.runInteraction(
+ "get_event_report", _get_event_report_txn, report_id
+ )
+
async def get_event_reports_paginate(
self,
start: int,
@@ -1385,18 +1529,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
er.room_id,
er.event_id,
er.user_id,
- er.reason,
er.content,
events.sender,
- room_aliases.room_alias,
- event_json.json AS event_json
+ room_stats_state.canonical_alias,
+ room_stats_state.name
FROM event_reports AS er
- LEFT JOIN room_aliases
- ON room_aliases.room_id = er.room_id
- JOIN events
+ LEFT JOIN events
ON events.event_id = er.event_id
- JOIN event_json
- ON event_json.event_id = er.event_id
+ JOIN room_stats_state
+ ON room_stats_state.room_id = er.room_id
{where_clause}
ORDER BY er.received_ts {order}
LIMIT ?
@@ -1407,15 +1548,29 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
args += [limit, start]
txn.execute(sql, args)
- event_reports = self.db_pool.cursor_to_dict(txn)
-
- if count > 0:
- for row in event_reports:
- try:
- row["content"] = db_to_json(row["content"])
- row["event_json"] = db_to_json(row["event_json"])
- except Exception:
- continue
+
+ event_reports = []
+ for row in txn:
+ try:
+ s = db_to_json(row[5]).get("score")
+ r = db_to_json(row[5]).get("reason")
+ except Exception:
+ logger.error("Unable to parse json from event_reports: %s", row[0])
+ continue
+ event_reports.append(
+ {
+ "id": row[0],
+ "received_ts": row[1],
+ "room_id": row[2],
+ "event_id": row[3],
+ "user_id": row[4],
+ "score": s,
+ "reason": r,
+ "sender": row[6],
+ "canonical_alias": row[7],
+ "name": row[8],
+ }
+ )
return event_reports, count
@@ -1446,88 +1601,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
-
- async def get_rooms_for_retention_period_in_range(
- self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, dict]:
- """Retrieves all of the rooms within the given retention range.
-
- Optionally includes the rooms which don't have a retention policy.
-
- Args:
- min_ms: Duration in milliseconds that define the lower limit of
- the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms: Duration in milliseconds that define the upper limit of
- the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null: Whether to include rooms which retention policy is NULL
- in the returned set.
-
- Returns:
- The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
- """
-
- def get_rooms_for_retention_period_in_range_txn(txn):
- range_conditions = []
- args = []
-
- if min_ms is not None:
- range_conditions.append("max_lifetime > ?")
- args.append(min_ms)
-
- if max_ms is not None:
- range_conditions.append("max_lifetime <= ?")
- args.append(max_ms)
-
- # Do a first query which will retrieve the rooms that have a retention policy
- # in their current state.
- sql = """
- SELECT room_id, min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- """
-
- if len(range_conditions):
- sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
- if include_null:
- sql += " OR max_lifetime IS NULL"
-
- txn.execute(sql, args)
-
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
-
- if include_null:
- # If required, do a second query that retrieves all of the rooms we know
- # of so we can handle rooms with no retention policy.
- sql = "SELECT DISTINCT room_id FROM current_state_events"
-
- txn.execute(sql)
-
- rows = self.db_pool.cursor_to_dict(txn)
-
- # If a room isn't already in the dict (i.e. it doesn't have a retention
- # policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
-
- return rooms_dict
-
- rooms = await self.db_pool.runInteraction(
- "get_rooms_for_retention_period_in_range",
- get_rooms_for_retention_period_in_range_txn,
- )
-
- return rooms
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..dcdaf09682 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -14,19 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
- LoggingTransaction,
- SQLBaseStore,
- db_to_json,
- make_in_list_sql_clause,
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
@@ -60,27 +58,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# background update still running?
self._current_state_events_membership_up_to_date = False
- txn = LoggingTransaction(
- db_conn.cursor(),
- name="_check_safe_current_state_events_membership_updated",
- database_engine=self.database_engine,
+ txn = db_conn.cursor(
+ txn_name="_check_safe_current_state_events_membership_updated"
)
self._check_safe_current_state_events_membership_updated_txn(txn)
txn.close()
- if self.hs.config.metrics_flags.known_servers:
+ if (
+ self.hs.config.run_background_tasks
+ and self.hs.config.metrics_flags.known_servers
+ ):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
- run_as_background_process,
- 60 * 1000,
- "_count_known_servers",
- self._count_known_servers,
+ self._count_known_servers, 60 * 1000,
)
self.hs.get_clock().call_later(
- 1000,
- run_as_background_process,
- "_count_known_servers",
- self._count_known_servers,
+ 1000, self._count_known_servers,
)
LaterGauge(
"synapse_federation_known_servers",
@@ -89,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
+ @wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -356,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
+ async def get_local_current_membership_for_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Tuple[Optional[str], Optional[str]]:
+ """Retrieve the current local membership state and event ID for a user in a room.
+
+ Args:
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+
+ Returns:
+            A tuple of (membership_type, event_id). Both will be None if the
+            room_id/user_id pair is not found.
+ """
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_local_current_membership_for_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ results_dict = await self.db_pool.simple_select_one(
+ "local_current_membership",
+ {"room_id": room_id, "user_id": user_id},
+ ("membership", "event_id"),
+ allow_none=True,
+ desc="get_local_current_membership_for_user_in_room",
+ )
+ if not results_dict:
+ return None, None
+
+ return results_dict.get("membership"), results_dict.get("event_id")
+
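A usage sketch of the new helper (IDs hypothetical):

    membership, event_id = await store.get_local_current_membership_for_user_in_room(
        "@alice:hs", "!room:hs"
    )
    if membership == "join":
        ...  # event_id is the m.room.member event that joined the user
    # Both values are None if no local membership row exists for the room.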
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
@@ -535,7 +561,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get(
+ prev_res = self._get_joined_users_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
if prev_res and isinstance(prev_res, dict):
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_token, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_origin_server_ts", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
- database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
- % (",".join("?" for _ in chunk),)
- ),
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),),
[as_id] + chunk,
)
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row = list(row)
row[12] = token_to_stream_ordering(row[12])
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_stream_ordering, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_stream_ordering, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search_order", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_fields_sender_url", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
- database_engine.convert_param_style(
- "UPDATE remote_media_cache SET last_access_ts = ?"
- ),
- (int(time.time() * 1000),),
+ "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
)
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
import logging
+from io import StringIO
from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
logger = logging.getLogger(__name__)
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
select_clause,
)
- if isinstance(database_engine, PostgresEngine):
- cur.execute(sql)
- else:
- cur.executescript(sql)
+ execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' AND state_key LIKE ?
"""
- sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("%:" + config.server_name,))
cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
index b64926e9c9..3275ae2b20 100644
--- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -20,14 +20,14 @@
*/
-- add new index that includes method to local media
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('local_media_repository_thumbnails_method_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5807, 'local_media_repository_thumbnails_method_idx', '{}');
-- add new index that includes method to remote media
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
-- drop old index
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+ user_id TEXT NOT NULL PRIMARY KEY,
+ device_id TEXT NOT NULL,
+ device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+ user_id TEXT NOT NULL, -- The user this fallback key is for.
+ device_id TEXT NOT NULL, -- The device this fallback key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+ CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
index cade5dcca8..fd733adf13 100644
--- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -28,5 +28,5 @@
-- functionality as the old one. This effectively restarts the background job
-- from the beginning, without running it twice in a row, supporting both
-- upgrade usecases.
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('populate_stats_process_rooms_2', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5812, 'populate_stats_process_rooms_2', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+ instance_id SERIAL PRIMARY KEY,
+ instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
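A sketch (not the actual Synapse helper) of how such a table could be used to resolve a name to its small ID, inserting it on first sight; `RETURNING` is safe here since this schema file is Postgres-only, and placeholders follow the "?" style used elsewhere in this diff:

    def get_instance_id_txn(txn, instance_name):
        # Hypothetical helper: look the name up, inserting it if unseen.
        txn.execute(
            "SELECT instance_id FROM instance_map WHERE instance_name = ?",
            (instance_name,),
        )
        row = txn.fetchone()
        if row:
            return row[0]

        txn.execute(
            "INSERT INTO instance_map (instance_name) VALUES (?) RETURNING instance_id",
            (instance_name,),
        )
        return txn.fetchone()[0]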
diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
new file mode 100644
index 0000000000..b2454121a8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
@@ -0,0 +1,40 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A map of recent events persisted with transaction IDs. Used to deduplicate
+-- send event requests with the same transaction ID.
+--
+-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
+-- used to make the request.
+--
+-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
+-- events or access token we don't want to try and de-duplicate the event.
+CREATE TABLE IF NOT EXISTS event_txn_id (
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ token_id BIGINT NOT NULL,
+ txn_id TEXT NOT NULL,
+ inserted_ts BIGINT NOT NULL,
+ FOREIGN KEY (event_id)
+ REFERENCES events (event_id) ON DELETE CASCADE,
+ FOREIGN KEY (token_id)
+ REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
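The unique index on `(room_id, user_id, token_id, txn_id)` is what makes the deduplication work: a second insert for the same scoped transaction ID fails, and the caller can instead look up the event ID that was stored first. A minimal sketch of that flow against an in-memory SQLite table (foreign keys and the `event_id` index omitted; `persist_or_dedupe` is an illustrative helper, not Synapse code):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE event_txn_id (
        event_id TEXT NOT NULL,
        room_id TEXT NOT NULL,
        user_id TEXT NOT NULL,
        token_id BIGINT NOT NULL,
        txn_id TEXT NOT NULL,
        inserted_ts BIGINT NOT NULL
    )
    """
)
conn.execute(
    "CREATE UNIQUE INDEX event_txn_id_txn_id"
    " ON event_txn_id(room_id, user_id, token_id, txn_id)"
)

def persist_or_dedupe(event_id, room_id, user_id, token_id, txn_id, now_ms):
    """Insert the mapping; on a duplicate scoped txn ID, return the event ID
    that won the race instead."""
    try:
        conn.execute(
            "INSERT INTO event_txn_id VALUES (?, ?, ?, ?, ?, ?)",
            (event_id, room_id, user_id, token_id, txn_id, now_ms),
        )
        return event_id
    except sqlite3.IntegrityError:
        row = conn.execute(
            "SELECT event_id FROM event_txn_id"
            " WHERE room_id = ? AND user_id = ? AND token_id = ? AND txn_id = ?",
            (room_id, user_id, token_id, txn_id),
        ).fetchone()
        return row[0]

assert persist_or_dedupe("$e1", "!room", "@alice", 1, "txn1", 0) == "$e1"
# A retried request with the same scoped txn ID maps back to the first event.
assert persist_or_dedupe("$e2", "!room", "@alice", 1, "txn1", 5) == "$e1"
```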
diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
new file mode 100644
index 0000000000..ad1f481428
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
new file mode 100644
index 0000000000..b0b5dcddce
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Add new column to user_daily_visits to track user agent
+ALTER TABLE user_daily_visits
+ ADD COLUMN user_agent TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
new file mode 100644
index 0000000000..7b84a207fd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
+ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
new file mode 100644
index 0000000000..01ea6eddcf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
@@ -0,0 +1 @@
+DROP TABLE device_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
new file mode 100644
index 0000000000..e1a35be831
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -0,0 +1,2 @@
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5822, 'users_have_local_media', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
new file mode 100644
index 0000000000..75c3915a94
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5823, 'e2e_cross_signing_keys_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
new file mode 100644
index 0000000000..8a39d54aed
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is essentially redundant. The only time it was ever used was when purging
+-- rooms - and Synapse 1.24 will change that.
+
+DROP INDEX IF EXISTS event_json_room_id;
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5beb302be3..0cdb3ec1f7 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,15 +16,18 @@
import logging
from collections import Counter
+from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -59,6 +62,23 @@ TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user
TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
+class UserSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning users
+ with get_users_media_usage_paginate.
+
+ MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
+ MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+ USER_ID = ordered alphabetically by `user_id`.
+ DISPLAYNAME = ordered alphabetically by `displayname`.
+ """
+
+ MEDIA_LENGTH = "media_length"
+ MEDIA_COUNT = "media_count"
+ USER_ID = "user_id"
+ DISPLAYNAME = "displayname"
+
+
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -882,3 +902,110 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=pos,
absolute_field_overrides={"joined_rooms": joined_rooms},
)
+
+ async def get_users_media_usage_paginate(
+ self,
+ start: int,
+ limit: int,
+ from_ts: Optional[int] = None,
+ until_ts: Optional[int] = None,
+ order_by: Optional[str] = UserSortOrder.USER_ID.value,
+ direction: Optional[str] = "f",
+ search_term: Optional[str] = None,
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of users and their uploaded local media
+ (size and number). This will return a json list of users and the
+ total number of users matching the filter criteria.
+
+ Args:
+ start: offset to begin the query from
+ limit: number of rows to retrieve
+ from_ts: request only media that are created later than this timestamp (ms)
+ until_ts: request only media that are created earlier than this timestamp (ms)
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
+ search_term: a string to filter user names by
+ Returns:
+ A list of user dicts and an integer representing the total number of
+ users that exist given this query
+ """
+
+ def get_users_media_usage_paginate_txn(txn):
+ filters = []
+ args = [self.hs.config.server_name]
+
+ if search_term:
+ filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + search_term + "%:%", "%" + search_term + "%"])
+
+ if from_ts:
+ filters.append("created_ts >= ?")
+ args.extend([from_ts])
+ if until_ts:
+ filters.append("created_ts <= ?")
+ args.extend([until_ts])
+
+ # Set ordering
+ if UserSortOrder(order_by) == UserSortOrder.MEDIA_LENGTH:
+ order_by_column = "media_length"
+ elif UserSortOrder(order_by) == UserSortOrder.MEDIA_COUNT:
+ order_by_column = "media_count"
+ elif UserSortOrder(order_by) == UserSortOrder.USER_ID:
+ order_by_column = "lmr.user_id"
+ elif UserSortOrder(order_by) == UserSortOrder.DISPLAYNAME:
+ order_by_column = "displayname"
+ else:
+ raise StoreError(
+ 500, "Incorrect value for order_by provided: %s" % order_by
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql_base = """
+ FROM local_media_repository as lmr
+ LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ?
+ {}
+ GROUP BY lmr.user_id, displayname
+ """.format(
+ where_clause
+ )
+
+ # SQLite does not support SELECT COUNT(*) OVER()
+ sql = """
+ SELECT COUNT(*) FROM (
+ SELECT lmr.user_id
+ {sql_base}
+ ) AS count_user_ids
+ """.format(
+ sql_base=sql_base,
+ )
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ lmr.user_id,
+ displayname,
+ COUNT(lmr.user_id) as media_count,
+ SUM(media_length) as media_length
+ {sql_base}
+ ORDER BY {order_by_column} {order}
+ LIMIT ? OFFSET ?
+ """.format(
+ sql_base=sql_base, order_by_column=order_by_column, order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ users = self.db_pool.cursor_to_dict(txn)
+
+ return users, count
+
+ return await self.db_pool.runInteraction(
+ "get_users_media_usage_paginate_txn", get_users_media_usage_paginate_txn
+ )
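For context, a caller such as an admin API handler would invoke the new method roughly as below; `store` and the concrete values are illustrative assumptions:

```python
from synapse.storage.databases.main.stats import UserSortOrder

async def largest_uploaders(store):
    # `store` is assumed to be a StatsStore; the values are illustrative.
    # Page through users by total uploaded media size, largest first.
    users, total = await store.get_users_media_usage_paginate(
        start=0,
        limit=100,
        order_by=UserSortOrder.MEDIA_LENGTH.value,
        direction="b",  # "b" sorts descending, anything else ascending
    )
    return users, total
```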
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
)
+def _filter_results(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ topological_ordering: int,
+ stream_ordering: int,
+) -> bool:
+ """Returns True if the event persisted by the given instance at the given
+ topological/stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+ Used to filter results from fetching events in the DB against the given
+ tokens. This is necessary to handle the case where the tokens include
+ position maps, which we handle by fetching more than necessary from the DB
+ and then filtering (rather than attempting to construct a complicated SQL
+ query).
+ """
+
+ event_historical_tuple = (
+ topological_ordering,
+ stream_ordering,
+ )
+
+ if lower_token:
+ if lower_token.topological is not None:
+ # If these are historical tokens we compare the `(topological, stream)`
+ # tuples.
+ if event_historical_tuple <= lower_token.as_historical_tuple():
+ return False
+
+ else:
+ # If these are live tokens we compare the stream ordering against the
+ # writer's stream position.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(
+ instance_name
+ ):
+ return False
+
+ if upper_token:
+ if upper_token.topological is not None:
+ if upper_token.as_historical_tuple() < event_historical_tuple:
+ return False
+ else:
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
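To make the `_filter_results` rule concrete, here is a self-contained sketch using a simplified stand-in for a live `RoomStreamToken` (only the stream position and per-writer instance map are modelled):

```python
from typing import Dict, Optional

class FakeToken:
    """Simplified stand-in for a live RoomStreamToken: an overall minimum
    stream position, plus optional per-writer positions ahead of it."""

    def __init__(self, stream: int, instance_map: Optional[Dict[str, int]] = None):
        self.topological = None  # live tokens have no topological part
        self.stream = stream
        self.instance_map = instance_map or {}

    def get_stream_pos_for_instance(self, instance_name: str) -> int:
        # Writers absent from the map fall back to the overall minimum.
        return self.instance_map.get(instance_name, self.stream)

lower = FakeToken(10, {"worker2": 15})

# worker1 isn't in the map, so its event at position 11 is after the token...
assert 11 > lower.get_stream_pos_for_instance("worker1")
# ...but worker2's recorded position is 15, so its event at 11 is already
# covered by the lower token and _filter_results would drop it.
assert not 11 > lower.get_stream_pos_for_instance("worker2")
```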
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
raise NotImplementedError()
def get_room_max_token(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.get_room_max_stream_ordering())
+ """Get a `RoomStreamToken` that marks the current maximum persisted
+ position of the events stream. Useful to get a token that represents
+ "now".
+
+ The token returned is a "live" token that may have an instance_map
+ component.
+ """
+
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return RoomStreamToken(None, min_pos, positions)
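The trimming step in `get_room_max_token` can be illustrated in isolation; this sketch uses a plain tuple in place of `RoomStreamToken`:

```python
from typing import Dict, Tuple

def room_max_token(
    min_pos: int, writer_positions: Dict[str, int]
) -> Tuple[None, int, Dict[str, int]]:
    """Sketch of the trimming step: only writers strictly ahead of the
    globally-persisted minimum need to appear in the token's instance map."""
    positions = {n: p for n, p in writer_positions.items() if p > min_pos}
    return (None, min_pos, positions)  # (topological, stream, instance_map)

# All writers at the minimum: the token collapses to a plain stream token.
assert room_max_token(10, {"w1": 10, "w2": 10}) == (None, 10, {})
# One writer ahead: only that writer is recorded in the map.
assert room_max_token(10, {"w1": 10, "w2": 12}) == (None, 10, {"w2": 12})
```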
async def get_room_events_stream_for_rooms(
self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
return [], from_key
- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key
def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT event_id, instance_name, topological_ordering, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ order,
+ )
+ txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ][:limit]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._set_before_and_after(ret, rows, topo_order=False)
if order.lower() == "desc":
ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
return []
- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ ORDER BY e.stream_ordering ASC
+ """
+ txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ]
return rows
@@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> Tuple[int, int, str]:
+ ) -> Optional[Tuple[int, int, str]]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
)
return "t%d-%d" % (topo, token)
- async def get_stream_id_for_event(self, event_id: str) -> int:
- """The stream ID for an event
- Args:
- event_id: The id of the event to look up a stream token for.
- Raises:
- StoreError if the event wasn't in the database.
- Returns:
- A stream ID.
- """
- return await self.db_pool.runInteraction(
- "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
- )
-
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
@@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ if from_token.topological is not None:
+ from_bound = (
+ from_token.as_historical_tuple()
+ ) # type: Tuple[Optional[int], int]
+ elif direction == "b":
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
- args.append(int(limit))
+ # We fetch more events as we'll filter the result set
+ args.append(int(limit) * 2)
select_keywords = "SELECT"
join_clause = ""
@@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords += "DISTINCT"
sql = """
- %(select_keywords)s event_id, topological_ordering, stream_ordering
+ %(select_keywords)s
+ event_id, instance_name,
+ topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
txn.execute(sql, args)
- rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+ # Filter the result set.
+ rows = [
+ _EventDictReturn(event_id, topological_ordering, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ lower_token=to_token if direction == "b" else from_token,
+ upper_token=from_token if direction == "b" else to_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ )
+ ][:limit]
if rows:
topo = rows[-1].topological_ordering
@@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return (events, token)
+ @cached()
+ async def get_id_for_instance(self, instance_name: str) -> int:
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+ """
+
+ def _get_id_for_instance_txn(txn):
+ instance_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ allow_none=True,
+ )
+ if instance_id is not None:
+ return instance_id
+
+ # If we don't have an entry upsert one.
+ #
+ # We could do this before the first check, and rely on the cache for
+ # efficiency, but each UPSERT causes the next ID to increment which
+ # can quickly bloat the size of the generated IDs for new instances.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ values={},
+ )
+
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ )
+
+ return await self.db_pool.runInteraction(
+ "get_id_for_instance", _get_id_for_instance_txn
+ )
+
+ @cached()
+ async def get_name_from_instance_id(self, instance_id: int) -> str:
+ """Get the instance name from an ID previously returned by
+ `get_id_for_instance`.
+ """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="instance_map",
+ keyvalues={"instance_id": instance_id},
+ retcol="instance_name",
+ desc="get_name_from_instance_id",
+ )
+
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
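The SELECT-before-UPSERT ordering in `get_id_for_instance` is worth dwelling on: on Postgres, a SERIAL default is evaluated before a conflict is detected, so an unconditional upsert consumes an ID on every call. A toy simulation of the two strategies (pure Python, no database):

```python
class FakeSequence:
    """Toy stand-in for a Postgres SERIAL column: the default is evaluated
    (nextval) before a conflict is detected, so every upsert attempt
    consumes an ID whether or not a new row is created."""

    def __init__(self) -> None:
        self.value = 0

    def nextval(self) -> int:
        self.value += 1
        return self.value

seq = FakeSequence()
ids = {}  # stand-in for the instance_map table

def upsert_first(name: str) -> int:
    # Upsert unconditionally: the nextval() argument is evaluated eagerly,
    # mimicking the sequence bloat the comment above warns about.
    return ids.setdefault(name, seq.nextval())

def select_first(name: str) -> int:
    # What get_id_for_instance does: only a genuine miss touches the sequence.
    if name in ids:
        return ids[name]
    return ids.setdefault(name, seq.nextval())

for _ in range(1000):
    upsert_first("worker1")
assert ids["worker1"] == 1 and seq.value == 1000  # 999 IDs burned

for _ in range(1000):
    select_first("worker2")
assert ids["worker2"] == 1001 and seq.value == 1001  # only one ID consumed
```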
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 97aed1500e..59207cadd4 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple(
SENTINEL = object()
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+ @wrap_as_background_process("cleanup_transactions")
+ async def _cleanup_transactions(self) -> None:
+ now = self._clock.time_msec()
+ month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+ def _cleanup_transactions_txn(txn):
+ txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+ await self.db_pool.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
+
+
+class TransactionStore(TransactionWorkerStore):
"""A collection of queries for handling PDUs.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
-
self._destination_retry_cache = ExpiringCache(
cache_name="get_destination_retry_timings",
clock=self._clock,
@@ -190,42 +208,56 @@ class TransactionStore(SQLBaseStore):
"""
self._destination_retry_cache.pop(destination, None)
- return await self.db_pool.runInteraction(
- "set_destination_retry_timings",
- self._set_destination_retry_timings,
- destination,
- failure_ts,
- retry_last_ts,
- retry_interval,
- )
+ if self.database_engine.can_native_upsert:
+ return await self.db_pool.runInteraction(
+ "set_destination_retry_timings",
+ self._set_destination_retry_timings_native,
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ db_autocommit=True, # Safe as it's a single upsert
+ )
+ else:
+ return await self.db_pool.runInteraction(
+ "set_destination_retry_timings",
+ self._set_destination_retry_timings_emulated,
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ )
- def _set_destination_retry_timings(
+ def _set_destination_retry_timings_native(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
+ assert self.database_engine.can_native_upsert
+
+ # Upsert retry time interval if retry_interval is zero (i.e. we're
+ # resetting it) or greater than the existing retry interval.
+ #
+ # WARNING: This is executed in autocommit, so we shouldn't add any more
+ # SQL calls in here (without being very careful).
+ sql = """
+ INSERT INTO destinations (
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET
+ failure_ts = EXCLUDED.failure_ts,
+ retry_last_ts = EXCLUDED.retry_last_ts,
+ retry_interval = EXCLUDED.retry_interval
+ WHERE
+ EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval IS NULL
+ OR destinations.retry_interval < EXCLUDED.retry_interval
+ """
- if self.database_engine.can_native_upsert:
- # Upsert retry time interval if retry_interval is zero (i.e. we're
- # resetting it) or greater than the existing retry interval.
-
- sql = """
- INSERT INTO destinations (
- destination, failure_ts, retry_last_ts, retry_interval
- )
- VALUES (?, ?, ?, ?)
- ON CONFLICT (destination) DO UPDATE SET
- failure_ts = EXCLUDED.failure_ts,
- retry_last_ts = EXCLUDED.retry_last_ts,
- retry_interval = EXCLUDED.retry_interval
- WHERE
- EXCLUDED.retry_interval = 0
- OR destinations.retry_interval IS NULL
- OR destinations.retry_interval < EXCLUDED.retry_interval
- """
-
- txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
-
- return
+ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+ def _set_destination_retry_timings_emulated(
+ self, txn, destination, failure_ts, retry_last_ts, retry_interval
+ ):
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
@@ -266,22 +298,6 @@ class TransactionStore(SQLBaseStore):
},
)
- def _start_cleanup_transactions(self):
- return run_as_background_process(
- "cleanup_transactions", self._cleanup_transactions
- )
-
- async def _cleanup_transactions(self) -> None:
- now = self._clock.time_msec()
- month_ago = now - 30 * 24 * 60 * 60 * 1000
-
- def _cleanup_transactions_txn(txn):
- txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
- await self.db_pool.runInteraction(
- "_cleanup_transactions", _cleanup_transactions_txn
- )
-
async def store_destination_rooms_entries(
self, destinations: Iterable[str], room_id: str, stream_ordering: int,
) -> None:
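The effect of the WHERE clause on the native upsert is easiest to see in isolation. A sketch against in-memory SQLite, whose `ON CONFLICT ... DO UPDATE` syntax is close enough to the Postgres statement above for this purpose (requires SQLite >= 3.24, which modern Python builds bundle):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE destinations (
        destination TEXT PRIMARY KEY,
        failure_ts BIGINT,
        retry_last_ts BIGINT,
        retry_interval BIGINT
    )
    """
)

UPSERT = """
    INSERT INTO destinations (destination, failure_ts, retry_last_ts, retry_interval)
    VALUES (?, ?, ?, ?)
    ON CONFLICT (destination) DO UPDATE SET
        failure_ts = excluded.failure_ts,
        retry_last_ts = excluded.retry_last_ts,
        retry_interval = excluded.retry_interval
    WHERE
        excluded.retry_interval = 0
        OR destinations.retry_interval IS NULL
        OR destinations.retry_interval < excluded.retry_interval
"""

conn.execute(UPSERT, ("example.com", 100, 100, 60_000))
# A smaller interval does not overwrite the stored back-off...
conn.execute(UPSERT, ("example.com", 200, 200, 30_000))
assert conn.execute("SELECT retry_interval FROM destinations").fetchone()[0] == 60_000
# ...but an interval of 0 always resets it.
conn.execute(UPSERT, ("example.com", 300, 300, 0))
assert conn.execute("SELECT retry_interval FROM destinations").fetchone()[0] == 0
```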
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
)
return [(row["user_agent"], row["ip"]) for row in rows]
-
-class UIAuthStore(UIAuthWorkerStore):
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
iterable=session_ids,
keyvalues={},
)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ pass
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 5a390ff2f6..d87ceec6da 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id_tuples: iterable of 2-tuple of user IDs.
"""
- def _add_users_who_share_room_txn(txn):
- self.db_pool.simple_upsert_many_txn(
- txn,
- table="users_who_share_private_rooms",
- key_names=["user_id", "other_user_id", "room_id"],
- key_values=[
- (user_id, other_user_id, room_id)
- for user_id, other_user_id in user_id_tuples
- ],
- value_names=(),
- value_values=None,
- )
-
- await self.db_pool.runInteraction(
- "add_users_who_share_room", _add_users_who_share_room_txn
+ await self.db_pool.simple_upsert_many(
+ table="users_who_share_private_rooms",
+ key_names=["user_id", "other_user_id", "room_id"],
+ key_values=[
+ (user_id, other_user_id, room_id)
+ for user_id, other_user_id in user_id_tuples
+ ],
+ value_names=(),
+ value_values=None,
+ desc="add_users_who_share_room",
)
async def add_users_in_public_rooms(
@@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_ids
"""
- def _add_users_in_public_rooms_txn(txn):
-
- self.db_pool.simple_upsert_many_txn(
- txn,
- table="users_in_public_rooms",
- key_names=["user_id", "room_id"],
- key_values=[(user_id, room_id) for user_id in user_ids],
- value_names=(),
- value_values=None,
- )
-
- await self.db_pool.runInteraction(
- "add_users_in_public_rooms", _add_users_in_public_rooms_txn
+ await self.db_pool.simple_upsert_many(
+ table="users_in_public_rooms",
+ key_names=["user_id", "room_id"],
+ key_values=[(user_id, room_id) for user_id in user_ids],
+ value_names=(),
+ value_values=None,
+ desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 72939f3984..70e636b0ba 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -96,7 +96,9 @@ class _EventPeristenceQueue:
Returns:
defer.Deferred: a deferred which will resolve once the events are
- persisted. Runs its callbacks *without* a logcontext.
+ persisted. Runs its callbacks *without* a logcontext. The result
+ is the same as that returned by the callback passed to
+ `handle_queue`.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
@@ -199,7 +201,7 @@ class EventsPersistenceStorage:
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> RoomStreamToken:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""
Write events to the database
Args:
@@ -209,7 +211,11 @@ class EventsPersistenceStorage:
which might update the current state etc.
Returns:
- the stream ordering of the latest persisted event
+ List of events persisted and the current room stream position.
+ The list of events persisted may not be the same as those passed in
+ if they were deduplicated due to an event already existing that
+ matched the transaction ID; the existing event is returned in such
+ a case.
"""
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
@@ -225,19 +231,41 @@ class EventsPersistenceStorage:
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- await make_deferred_yieldable(
+ # Each deferred returns a map from event ID to existing event ID if the
+ # event was deduplicated. (The dict may also include other entries if
+ # the event was persisted in a batch with other events).
+ #
+ # Since we use `defer.gatherResults` we need to merge the returned list
+ # of dicts into one.
+ ret_vals = await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
+ replaced_events = {}
+ for d in ret_vals:
+ replaced_events.update(d)
+
+ events = []
+ for event, _ in events_and_contexts:
+ existing_event_id = replaced_events.get(event.event_id)
+ if existing_event_id:
+ events.append(await self.main_store.get_event(existing_event_id))
+ else:
+ events.append(event)
- return self.main_store.get_room_max_token()
+ return (
+ events,
+ self.main_store.get_room_max_token(),
+ )
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
+ ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
"""
Returns:
- The stream ordering of `event`, and the stream ordering of the
- latest persisted event
+ The event, stream ordering of `event`, and the stream ordering of the
+ latest persisted event. The returned event may not match the given
+ event if it was deduplicated due to an existing event matching the
+ transaction ID.
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@@ -245,17 +273,33 @@ class EventsPersistenceStorage:
self._maybe_start_persisting(event.room_id)
- await make_deferred_yieldable(deferred)
+ # The deferred returns a map from event ID to existing event ID if the
+ # event was deduplicated. (The dict may also include other entries if
+ # the event was persisted in a batch with other events.)
+ replaced_events = await make_deferred_yieldable(deferred)
+ replaced_event = replaced_events.get(event.event_id)
+ if replaced_event:
+ event = await self.main_store.get_event(replaced_event)
event_stream_id = event.internal_metadata.stream_ordering
+ # stream ordering should have been assigned by now
+ assert event_stream_id
pos = PersistedEventPosition(self._instance_name, event_stream_id)
- return pos, self.main_store.get_room_max_token()
+ return event, pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str):
+ """Pokes the `_event_persist_queue` to start handling new items in the
+ queue, if not already in progress.
+
+ Causes the deferreds returned by `add_to_queue` to resolve with a
+ dictionary mapping each deduplicated event ID to the ID of the event
+ that was already persisted with the same TXN ID.
+ """
+
async def persisting_queue(item):
with Measure(self._clock, "persist_events"):
- await self._persist_events(
+ return await self._persist_events(
item.events_and_contexts, backfilled=item.backfilled
)
@@ -265,12 +309,38 @@ class EventsPersistenceStorage:
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ):
+ ) -> Dict[str, str]:
"""Calculates the change to current state and forward extremities, and
persists the given events and with those updates.
+
+ Returns:
+ A dictionary mapping each deduplicated event ID to the ID of the
+ event that was already persisted with the same TXN ID.
"""
+ replaced_events = {} # type: Dict[str, str]
if not events_and_contexts:
- return
+ return replaced_events
+
+ # Check if any of the events have a transaction ID that has already been
+ # persisted, and if so we don't persist it again.
+ #
+ # We should have checked this a long time before we get here, but it's
+ # possible that different send event requests race in such a way that
+ # they both pass the earlier checks. Checking here isn't racy, as only
+ # one `_persist_events` call per room can be running at a time.
+ replaced_events = await self.main_store.get_already_persisted_events(
+ (event for event, _ in events_and_contexts)
+ )
+
+ if replaced_events:
+ events_and_contexts = [
+ (e, ctx)
+ for e, ctx in events_and_contexts
+ if e.event_id not in replaced_events
+ ]
+
+ if not events_and_contexts:
+ return replaced_events
chunks = [
events_and_contexts[x : x + 100]
@@ -439,6 +509,8 @@ class EventsPersistenceStorage:
await self._handle_potentially_left_users(potentially_left_users)
+ return replaced_events
+
async def _calculate_new_extremities(
self,
room_id: str,
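The merge-and-substitute dance around `gatherResults` reduces to two small pure functions; a sketch with event IDs standing in for full `EventBase` objects (the real code fetches them via `get_event`):

```python
from typing import Dict, List

def merge_replaced(ret_vals: List[Dict[str, str]]) -> Dict[str, str]:
    """Fold the per-batch event ID -> existing event ID maps returned via
    defer.gatherResults into a single dict."""
    replaced: Dict[str, str] = {}
    for d in ret_vals:
        replaced.update(d)
    return replaced

def substitute(event_ids: List[str], replaced: Dict[str, str]) -> List[str]:
    """Swap deduplicated events for the ones that actually got persisted."""
    return [replaced.get(event_id, event_id) for event_id in event_ids]

replaced = merge_replaced([{"$dup": "$orig"}, {}])
assert substitute(["$dup", "$other"], replaced) == ["$orig", "$other"]
```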
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 4957e77f4c..459754feab 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import imp
import logging
import os
@@ -24,9 +23,10 @@ from typing import Optional, TextIO
import attr
from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
def prepare_database(
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ["main", "state"],
@@ -89,7 +89,7 @@ def prepare_database(
"""
try:
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="prepare_database")
# sqlite does not automatically start transactions for DDL / SELECT statements,
# so we start one before running anything. This ensures that any upgrades
@@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases):
executescript(cur, entry.absolute_path)
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
- ),
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
(max_current_ver, False),
)
@@ -486,17 +484,13 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
- ),
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
(v, relative_path),
)
cur.execute("DELETE FROM schema_version")
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
- ),
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
(v, True),
)
@@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
schemas to be applied
"""
cur.execute(
- database_engine.convert_param_style(
- "SELECT file FROM applied_module_schemas WHERE module_name = ?"
- ),
- (modname,),
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
)
applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
@@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
- ),
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
(modname, name),
)
@@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine):
if current_version:
txn.execute(
- database_engine.convert_param_style(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?"
- ),
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
)
applied_deltas = [d for d, in txn]
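The manual `convert_param_style` calls can go because `db_conn.cursor()` now hands back a `LoggingTransaction` (via `LoggingDatabaseConnection`), which rewrites the `?` placeholders for the active engine centrally. An illustrative sketch of that kind of converting cursor proxy (not Synapse's exact implementation):

```python
class ConvertingCursor:
    """Illustrative cursor proxy: rewrite '?' placeholders to the native
    paramstyle once, centrally, so call sites can always write '?'."""

    def __init__(self, cursor, paramstyle: str):
        self._cursor = cursor
        self._paramstyle = paramstyle  # e.g. "qmark" or "format"

    def execute(self, sql: str, args=()):
        if self._paramstyle == "format":
            # Crude conversion; real code must not touch '?' inside string
            # literals, which Synapse sidesteps by never embedding them.
            sql = sql.replace("?", "%s")
        return self._cursor.execute(sql, args)

    def __getattr__(self, name):
        # Everything else passes straight through to the real cursor.
        return getattr(self._cursor, name)
```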
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..9cadcba18f 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing_extensions import Protocol
@@ -61,3 +61,9 @@ class Connection(Protocol):
def rollback(self, *args, **kwargs) -> None:
...
+
+ def __enter__(self) -> "Connection":
+ ...
+
+ def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+ ...
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index eccd2d5b7b..02d71302ea 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -55,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1):
"""
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_id")
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
@@ -270,7 +270,7 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
):
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream.
if self._writers:
@@ -284,15 +284,12 @@ class MultiWriterIdGenerator:
stream_name = ?
AND instance_name != ALL(?)
"""
- sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (self._stream_name, self._writers))
sql = """
SELECT instance_name, stream_id FROM stream_positions
WHERE stream_name = ?
"""
- sql = self._db.engine.convert_param_style(sql)
-
cur.execute(sql, (self._stream_name,))
self._current_positions = {
@@ -341,7 +338,6 @@ class MultiWriterIdGenerator:
"instance": instance_column,
"cmp": "<=" if self._positive else ">=",
}
- sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (min_stream_id * self._return_factor,))
self._persisted_upto_position = min_stream_id
@@ -422,7 +418,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
- new_cur = None
+ new_cur = None # type: Optional[int]
if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the
@@ -528,6 +524,16 @@ class MultiWriterIdGenerator:
heapq.heappush(self._known_persisted_positions, new_id)
+ # If we're a writer and we don't have any active writes we update our
+ # current position to the latest position seen. This allows the instance
+ # to report a recent position when asked, rather than a potentially old
+ # one (if this instance hasn't written anything for a while).
+ our_current_position = self._current_positions.get(self._instance_name)
+ if our_current_position and not self._unfinished_ids:
+ self._current_positions[self._instance_name] = max(
+ our_current_position, new_id
+ )
+
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
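The new idle-writer rule above can be exercised on its own; a sketch with plain dicts and sets standing in for the generator's state:

```python
from typing import Dict, Set

def advance_idle_writer(
    positions: Dict[str, int], unfinished: Set[int], name: str, new_id: int
) -> None:
    """Sketch of the rule above: a writer with no writes in flight can
    safely report any position it has seen, since it is guaranteed not to
    persist anything at or below that point."""
    ours = positions.get(name)
    if ours and not unfinished:
        positions[name] = max(ours, new_id)

positions = {"w1": 5}
advance_idle_writer(positions, set(), "w1", 9)
assert positions == {"w1": 9}
# With a write in flight, the position must not jump ahead of it.
advance_idle_writer(positions, {10}, "w1", 12)
assert positions == {"w1": 9}
```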
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 2dd95e2709..4386b6101e 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -17,6 +17,7 @@ import logging
import threading
from typing import Callable, List, Optional
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
@@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
@@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator):
return [i for (i,) in txn]
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
):
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
@@ -117,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator):
if max_stream_id > last_value:
logger.warning(
"Postgres sequence %s is behind table %s: %d < %d",
+ self._sequence_name,
+ table,
last_value,
max_stream_id,
)
|