diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
db_name = database_config.name
engine = create_engine(database_config.config)
- with make_conn(database_config, engine) as db_conn:
+ with make_conn(database_config, engine, "startup") as db_conn:
logger.info("[database config %r]: Checking database server", db_name)
engine.check_database(db_conn)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0cb12f4c61..9b16f45f3e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import calendar
import logging
-import time
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
@@ -268,9 +266,6 @@ class DataStore(
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
- # Used in _generate_user_daily_visits to keep track of progress
- self._last_user_visit_update = self._get_start_of_day()
-
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -289,7 +284,6 @@ class DataStore(
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
- sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
@@ -301,192 +295,6 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- async def count_daily_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 24 hours.
- """
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return await self.db_pool.runInteraction(
- "count_daily_users", self._count_users, yesterday
- )
-
- async def count_monthly_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 30 days.
- Note this method is intended for phonehome metrics only and is different
- from the mau figure in synapse.storage.monthly_active_users which,
- amongst other things, includes a 3 day grace period before a user counts.
- """
- thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return await self.db_pool.runInteraction(
- "count_monthly_users", self._count_users, thirty_days_ago
- )
-
- def _count_users(self, txn, time_from):
- """
- Returns number of users seen in the past time_from period
- """
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
- txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
- return count
-
- async def count_r30_users(self) -> Dict[str, int]:
- """
- Counts the number of 30 day retained users, defined as:-
- * Users who have created their accounts more than 30 days ago
- * Where last seen at most 30 days ago
- * Where account creation and last_seen are > 30 days apart
-
- Returns:
- A mapping of counts globally as well as broken out by platform.
- """
-
- def _count_r30_users(txn):
- thirty_days_in_secs = 86400 * 30
- now = int(self._clock.time())
- thirty_days_ago_in_secs = now - thirty_days_in_secs
-
- sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
- SELECT
- users.name, platform, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen,
- CASE
- WHEN user_agent LIKE '%%Android%%' THEN 'android'
- WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
- WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
- WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
- WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
- ELSE 'unknown'
- END
- AS platform
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND users.appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, platform, users.creation_ts
- ) u GROUP BY platform
- """
-
- results = {}
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- for row in txn:
- if row[0] == "unknown":
- pass
- results[row[0]] = row[1]
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT users.name, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, users.creation_ts
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- (count,) = txn.fetchone()
- results["all"] = count
-
- return results
-
- return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
- def _get_start_of_day(self):
- """
- Returns millisecond unixtime for start of UTC day.
- """
- now = time.gmtime()
- today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
- return today_start * 1000
-
- async def generate_user_daily_visits(self) -> None:
- """
- Generates daily visit data for use in cohort/ retention analysis
- """
-
- def _generate_user_daily_visits(txn):
- logger.info("Calling _generate_user_daily_visits")
- today_start = self._get_start_of_day()
- a_day_in_milliseconds = 24 * 60 * 60 * 1000
- now = self.clock.time_msec()
-
- sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
- FROM user_ips AS u
- LEFT JOIN (
- SELECT user_id, device_id, timestamp FROM user_daily_visits
- WHERE timestamp = ?
- ) udv
- ON u.user_id = udv.user_id AND u.device_id=udv.device_id
- INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
- AND udv.timestamp IS NULL AND users.is_guest=0
- AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
- """
-
- # This means that the day has rolled over but there could still
- # be entries from the previous day. There is an edge case
- # where if the user logs in at 23:59 and overwrites their
- # last_seen at 00:01 then they will not be counted in the
- # previous day's stats - it is important that the query is run
- # often to minimise this case.
- if today_start > self._last_user_visit_update:
- yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(
- sql,
- (
- yesterday_start,
- yesterday_start,
- self._last_user_visit_update,
- today_start,
- ),
- )
- self._last_user_visit_update = today_start
-
- txn.execute(
- sql, (today_start, today_start, self._last_user_visit_update, now)
- )
- # Update _last_user_visit_update to now. The reason to do this
- # rather just clamping to the beginning of the day is to limit
- # the size of the join - meaning that the query can be run more
- # frequently
- self._last_user_visit_update = now
-
- await self.db_pool.runInteraction(
- "generate_user_daily_visits", _generate_user_daily_visits
- )
-
async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef81d73573..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
import logging
from typing import Dict, List, Optional, Tuple
+from synapse.api.constants import AccountDataTypes
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
) -> bool:
ignored_account_data = await self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list",
+ AccountDataTypes.IGNORED_USER_LIST,
ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
return False
- return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ try:
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ except TypeError:
+ # The type of the ignored_users field is invalid.
+ return False
class AccountDataStore(AccountDataWorkerStore):
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 4bb2b9c28c..849bd5ba7a 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,7 +17,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.events.utils import prune_event_dict
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
@@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- def _censor_redactions():
- return run_as_background_process(
- "_censor_redactions", self._censor_redactions
- )
-
- if self.hs.config.redaction_retention_period is not None:
- hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+ if (
+ hs.config.run_background_tasks
+ and self.hs.config.redaction_retention_period is not None
+ ):
+ hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
+ @wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self):
"""Censors all redactions older than the configured period that haven't
been censored yet.
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 2611b92281..c30746c886 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -351,7 +351,63 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
+
+ if hs.config.run_background_tasks and self.user_ips_max_age:
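+ # Prune frequently: each pass deletes at most 5000 rows (see the LIMIT in
+ # _prune_old_user_ips), so clearing a large backlog takes many passes.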
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.db_pool.updates.has_completed_background_update(
+ "devices_last_seen"
+ ):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try and delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
+ )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.db_pool.runInteraction(
+ "_prune_old_user_ips", _prune_old_user_ips_txn
+ )
+
+
+class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache(
@@ -360,8 +416,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
- self.user_ips_max_age = hs.config.user_ips_max_age
-
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
- if self.user_ips_max_age:
- self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]
-
- @wrap_as_background_process("prune_old_user_ips")
- async def _prune_old_user_ips(self):
- """Removes entries in user IPs older than the configured period.
- """
-
- if self.user_ips_max_age is None:
- # Nothing to do
- return
-
- if not await self.db_pool.updates.has_completed_background_update(
- "devices_last_seen"
- ):
- # Only start pruning if we have finished populating the devices
- # last seen info.
- return
-
- # We do a slightly funky SQL delete to ensure we don't try and delete
- # too much at once (as the table may be very large from before we
- # started pruning).
- #
- # This works by finding the max last_seen that is less than the given
- # time, but has no more than N rows before it, deleting all rows with
- # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
- # returns exactly one row).
- sql = """
- DELETE FROM user_ips
- WHERE last_seen <= (
- SELECT COALESCE(MAX(last_seen), -1)
- FROM (
- SELECT last_seen FROM user_ips
- WHERE last_seen <= ?
- ORDER BY last_seen ASC
- LIMIT 5000
- ) AS u
- )
- """
-
- timestamp = self.clock.time_msec() - self.user_ips_max_age
-
- def _prune_old_user_ips_txn(txn):
- txn.execute(sql, (timestamp,))
-
- await self.db_pool.runInteraction(
- "_prune_old_user_ips", _prune_old_user_ips_txn
- )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..88fd97e1df 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
trace,
whitelisted_homeserver,
)
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -33,7 +33,7 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -48,6 +48,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(
+ self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+ )
+
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -698,6 +706,172 @@ class DeviceWorkerStore(SQLBaseStore):
_mark_remote_user_device_list_as_unsubscribed_txn,
)
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ # FIXME: make sure device ID still exists in devices table
+ row = await self.db_pool.simple_select_one(
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcols=["device_id", "device_data"],
+ allow_none=True,
+ )
+ return (
+ (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+ )
+
+ def _store_dehydrated_device_txn(
+ self, txn, user_id: str, device_id: str, device_data: str
+ ) -> Optional[str]:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcol="device_id",
+ allow_none=True,
+ )
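+ # There is at most one dehydrated device per user, so upserting on user_id
+ # replaces any existing row; the previous device ID is returned so the
+ # caller can clean it up.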
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ values={"device_id": device_id, "device_data": device_data},
+ )
+ return old_device_id
+
+ async def store_dehydrated_device(
+ self, user_id: str, device_id: str, device_data: JsonDict
+ ) -> Optional[str]:
+ """Store a dehydrated device for a user.
+
+ Args:
+ user_id: the user that we are storing the device for
+ device_id: the ID of the dehydrated device
+ device_data: the dehydrated device information
+ Returns:
+ device id of the user's previous dehydrated device, if any
+ """
+ return await self.db_pool.runInteraction(
+ "store_dehydrated_device_txn",
+ self._store_dehydrated_device_txn,
+ user_id,
+ device_id,
+ json_encoder.encode(device_data),
+ )
+
+ async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+ """Remove a dehydrated device.
+
+ Args:
+ user_id: the user that the dehydrated device belongs to
+ device_id: the ID of the dehydrated device
+ """
+ count = await self.db_pool.simple_delete(
+ "dehydrated_devices",
+ {"user_id": user_id, "device_id": device_id},
+ desc="remove_dehydrated_device",
+ )
+ return count >= 1
+
+ @wrap_as_background_process("prune_old_outbound_device_pokes")
+ async def _prune_old_outbound_device_pokes(
+ self, prune_age: int = 24 * 60 * 60 * 1000
+ ) -> None:
+ """Delete old entries out of the device_lists_outbound_pokes to ensure
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send an m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
+ """
+ yesterday = self._clock.time_msec() - prune_age
+
+ def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
+ select_sql = """
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
+ GROUP BY destination, user_id
+ HAVING min(ts) < ? AND count(*) > 1
+ """
+
+ txn.execute(select_sql, (yesterday,))
+ rows = txn.fetchall()
+
+ if not rows:
+ return
+
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+ # same stream_id, so we also delete all but one of the rows with the max stream id.
+ delete_sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
+ """
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
+
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # for the last successful send so that the prev_ids are correctly set.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+ logger.info("Pruned %d device list outbound pokes", count)
+
+ await self.db_pool.runInteraction(
+ "_prune_old_outbound_device_pokes", _prune_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -834,10 +1008,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
name="device_id_exists", keylen=2, max_entries=10000
)
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: str
+ self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -955,7 +1127,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def update_remote_device_list_cache_entry(
- self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
@@ -983,7 +1155,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
content: JsonDict,
- stream_id: int,
+ stream_id: str,
) -> None:
if content.get("deleted"):
self.db_pool.simple_delete_txn(
@@ -1193,95 +1365,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids
],
)
-
- def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
- """Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers.
-
- Normally, we try to send device updates as a delta since a previous known point:
- this is done by setting the prev_id in the m.device_list_update EDU. However,
- for that to work, we have to have a complete record of each change to
- each device, which can add up to quite a lot of data.
-
- An alternative mechanism is that, if the remote server sees that it has missed
- an entry in the stream_id sequence for a given user, it will request a full
- list of that user's devices. Hence, we can reduce the amount of data we have to
- store (and transmit in some future transaction), by clearing almost everything
- for a given destination out of the database, and having the remote server
- resync.
-
- All we need to do is make sure we keep at least one row for each
- (user, destination) pair, to remind us to send a m.device_list_update EDU for
- that user when the destination comes back. It doesn't matter which device
- we keep.
- """
- yesterday = self._clock.time_msec() - prune_age
-
- def _prune_txn(txn):
- # look for (user, destination) pairs which have an update older than
- # the cutoff.
- #
- # For each pair, we also need to know the most recent stream_id, and
- # an arbitrary device_id at that stream_id.
- select_sql = """
- SELECT
- dlop1.destination,
- dlop1.user_id,
- MAX(dlop1.stream_id) AS stream_id,
- (SELECT MIN(dlop2.device_id) AS device_id FROM
- device_lists_outbound_pokes dlop2
- WHERE dlop2.destination = dlop1.destination AND
- dlop2.user_id=dlop1.user_id AND
- dlop2.stream_id=MAX(dlop1.stream_id)
- )
- FROM device_lists_outbound_pokes dlop1
- GROUP BY destination, user_id
- HAVING min(ts) < ? AND count(*) > 1
- """
-
- txn.execute(select_sql, (yesterday,))
- rows = txn.fetchall()
-
- if not rows:
- return
-
- logger.info(
- "Pruning old outbound device list updates for %i users/destinations: %s",
- len(rows),
- shortstr((row[0], row[1]) for row in rows),
- )
-
- # we want to keep the update with the highest stream_id for each user.
- #
- # there might be more than one update (with different device_ids) with the
- # same stream_id, so we also delete all but one rows with the max stream id.
- delete_sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND (
- stream_id < ? OR
- (stream_id = ? AND device_id != ?)
- )
- """
- count = 0
- for (destination, user_id, stream_id, device_id) in rows:
- txn.execute(
- delete_sql, (destination, user_id, stream_id, stream_id, device_id)
- )
- count += txn.rowcount
-
- # Since we've deleted unsent deltas, we need to remove the entry
- # of last successful sent so that the prev_ids are correctly set.
- sql = """
- DELETE FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
- logger.info("Pruned %d device list outbound pokes", count)
-
- return run_as_background_process(
- "prune_old_outbound_device_pokes",
- self.db_pool.runInteraction,
- "_prune_old_outbound_device_pokes",
- _prune_txn,
- )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..4415909414 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -367,6 +367,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+ async def set_e2e_fallback_keys(
+ self, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
+ # fallback_keys will usually only have one item in it, so using a for
+ # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+ # FIXME: make sure that only one key per algorithm is uploaded
+ for key_id, fallback_key in fallback_keys.items():
+ algorithm, key_id = key_id.split(":", 1)
+ await self.db_pool.simple_upsert(
+ "e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ desc="set_e2e_fallback_key",
+ )
+
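+ # The set of unused fallback key types for this device may have changed,
+ # so invalidate the cache.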
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ @cached(max_entries=10000)
+ async def get_e2e_unused_fallback_key_types(
+ self, user_id: str, device_id: str
+ ) -> List[str]:
+ """Returns the fallback key types that have an unused key.
+
+ Args:
+ user_id: the user whose keys are being queried
+ device_id: the device whose keys are being queried
+
+ Returns:
+ a list of key types
+ """
+ return await self.db_pool.simple_select_onecol(
+ "e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+ retcol="algorithm",
+ desc="get_e2e_unused_fallback_key_types",
+ )
+
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
@@ -701,15 +756,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
+ fallback_sql = (
+ "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
result = {}
delete = []
+ used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
+ otk_row = txn.fetchone()
+ if otk_row is not None:
+ key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
+ else:
+ # no one-time key available, so see if there's a fallback
+ # key
+ txn.execute(fallback_sql, (user_id, device_id, algorithm))
+ fallback_row = txn.fetchone()
+ if fallback_row is not None:
+ key_id, key_json, used = fallback_row
+ device_result[algorithm + ":" + key_id] = key_json
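+ # Unlike one-time keys, fallback keys are not deleted when
+ # claimed; they are only flagged as used (below).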
+ if not used:
+ used_fallbacks.append(
+ (user_id, device_id, algorithm, key_id)
+ )
+
+ # drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +803,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ # mark fallback keys as used
+ for user_id, device_id, algorithm, key_id in used_fallbacks:
+ self.db_pool.simple_update_txn(
+ txn,
+ "e2e_fallback_keys_json",
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ {"used": True},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
return result
return await self.db_pool.runInteraction(
@@ -754,6 +848,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..a6279a6c13 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,7 +19,7 @@ from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
from synapse.events import EventBase
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -32,6 +32,14 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ hs.get_clock().looping_call(
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ )
+
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -586,6 +594,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row["event_id"] for row in rows]
+ @wrap_as_background_process("delete_old_forward_extrem_cache")
+ async def _delete_old_forward_extrem_cache(self) -> None:
+ def _delete_old_forward_extrem_cache_txn(txn):
+ # Delete entries older than a month, while making sure we don't delete
+ # the only entries for a room.
+ sql = """
+ DELETE FROM stream_ordering_to_exterm
+ WHERE
+ room_id IN (
+ SELECT room_id
+ FROM stream_ordering_to_exterm
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
+ """
+ txn.execute(
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+ )
+
+ await self.db_pool.runInteraction(
+ "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+ )
+
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
@@ -606,34 +636,6 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
- hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
- )
-
- def _delete_old_forward_extrem_cache(self):
- def _delete_old_forward_extrem_cache_txn(txn):
- # Delete entries older than a month, while making sure we don't delete
- # the only entries for a room.
- sql = """
- DELETE FROM stream_ordering_to_exterm
- WHERE
- room_id IN (
- SELECT room_id
- FROM stream_ordering_to_exterm
- WHERE stream_ordering > ?
- ) AND stream_ordering < ?
- """
- txn.execute(
- sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
- )
-
- return run_as_background_process(
- "delete_old_forward_extrem_cache",
- self.db_pool.runInteraction,
- "_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn,
- )
-
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..2e56dfaf31 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,15 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Dict, List, Optional, Tuple, Union
import attr
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.stream_ordering_month_ago = None
self.stream_ordering_day_ago = None
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- )
+ cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
cur.close()
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
+
self._rotate_delay = 3
self._rotate_count = 10000
+ self._doing_notif_rotation = False
+ if hs.config.run_background_tasks:
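+ # Only the instance handling background tasks runs the notification
+ # rotation loop.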
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
@@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Error removing push actions after event persistence failure"
)
- def _find_stream_orderings_for_times(self):
- return run_as_background_process(
- "event_push_action_stream_orderings",
- self.db_pool.runInteraction,
+ @wrap_as_background_process("event_push_action_stream_orderings")
+ async def _find_stream_orderings_for_times(self) -> None:
+ await self.db_pool.runInteraction(
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
- def _find_stream_orderings_for_times_txn(self, txn):
+ def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
logger.info("Searching for stream ordering 1 month ago")
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
@@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
return result[0] if result else None
-
-class EventPushActionsStore(EventPushActionsWorkerStore):
- EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self.db_pool.updates.register_background_index_update(
- self.EPA_HIGHLIGHT_INDEX,
- index_name="event_push_actions_u_highlight",
- table="event_push_actions",
- columns=["user_id", "stream_ordering"],
- )
-
- self.db_pool.updates.register_background_index_update(
- "event_push_actions_highlights_index",
- index_name="event_push_actions_highlights_index",
- table="event_push_actions",
- columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1",
- )
-
- self._doing_notif_rotation = False
- self._rotate_notif_loop = self._clock.looping_call(
- self._start_rotate_notifs, 30 * 60 * 1000
- )
-
- async def get_push_actions_for_user(
- self, user_id, before=None, limit=50, only_highlight=False
- ):
- def f(txn):
- before_clause = ""
- if before:
- before_clause = "AND epa.stream_ordering < ?"
- args = [user_id, before, limit]
- else:
- args = [user_id, limit]
-
- if only_highlight:
- if len(before_clause) > 0:
- before_clause += " "
- before_clause += "AND epa.highlight = 1"
-
- # NB. This assumes event_ids are globally unique since
- # it makes the query easier to index
- sql = (
- "SELECT epa.event_id, epa.room_id,"
- " epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
- " FROM event_push_actions epa, events e"
- " WHERE epa.event_id = e.event_id"
- " AND epa.user_id = ? %s"
- " AND epa.notif = 1"
- " ORDER BY epa.stream_ordering DESC"
- " LIMIT ?" % (before_clause,)
- )
- txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
-
- push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
- for pa in push_actions:
- pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- return push_actions
-
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
- def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
-
- Args:
- txn: The transcation
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id),
- )
-
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
- )
-
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
-
- def _start_rotate_notifs(self):
- return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+ @wrap_as_background_process("rotate_notifs")
async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
@@ -958,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1",
+ )
+
+ async def get_push_actions_for_user(
+ self, user_id, before=None, limit=50, only_highlight=False
+ ):
+ def f(txn):
+ before_clause = ""
+ if before:
+ before_clause = "AND epa.stream_ordering < ?"
+ args = [user_id, before, limit]
+ else:
+ args = [user_id, limit]
+
+ if only_highlight:
+ if len(before_clause) > 0:
+ before_clause += " "
+ before_clause += "AND epa.highlight = 1"
+
+ # NB. This assumes event_ids are globally unique since
+ # it makes the query easier to index
+ sql = (
+ "SELECT epa.event_id, epa.room_id,"
+ " epa.stream_ordering, epa.topological_ordering,"
+ " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+ " FROM event_push_actions epa, events e"
+ " WHERE epa.event_id = e.event_id"
+ " AND epa.user_id = ? %s"
+ " AND epa.notif = 1"
+ " ORDER BY epa.stream_ordering DESC"
+ " LIMIT ?" % (before_clause,)
+ )
+ txn.execute(sql, args)
+ return self.db_pool.cursor_to_dict(txn)
+
+ push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+ for pa in push_actions:
+ pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+ return push_actions
+
+ async def get_latest_push_action_stream_ordering(self):
+ def f(txn):
+ txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+ return txn.fetchone()
+
+ result = await self.db_pool.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
+ return result[0] or 0
+
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+ We do, however, keep a month's worth of highlighted notifications, so that
+ users can still get a list of recent highlights.
+
+ Args:
+ txn: The transaction
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+ # We need to join on the events table to get the received_ts for
+ # event_push_actions and sqlite won't let us use a join in a delete so
+ # we can't just delete where received_ts < x. Furthermore we can
+ # only identify event_push_actions by a tuple of room_id, event_id
+ # so we can't use a subquery.
+ # Instead, we look up the stream ordering for the last event in that
+ # room received before the threshold time and delete event_push_actions
+ # in the room with a stream_ordering before that.
+ txn.execute(
+ "DELETE FROM event_push_actions "
+ " WHERE user_id = ? AND room_id = ? AND "
+ " stream_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ )
+
+ txn.execute(
+ """
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
+
def _action_has_highlight(actions):
for action in actions:
try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 78e645592f..fdb17745f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -331,6 +331,10 @@ class PersistEventsStore:
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+ # stream orderings should have been assigned by now
+ assert min_stream_order
+ assert max_stream_order
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -357,6 +361,8 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
+ self._persist_transaction_ids_txn(txn, events_and_contexts)
+
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
@@ -401,6 +407,35 @@ class PersistEventsStore:
# room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ def _persist_transaction_ids_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ):
+ """Persist the mapping from transaction IDs to event IDs (if defined).
+ """
+
+ to_insert = []
+ for event, _ in events_and_contexts:
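+ # token_id (the access token ID) and txn_id (the client-supplied
+ # transaction ID) are typically only set on events sent by local
+ # clients via the client-server API.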
+ token_id = getattr(event.internal_metadata, "token_id", None)
+ txn_id = getattr(event.internal_metadata, "txn_id", None)
+ if token_id and txn_id:
+ to_insert.append(
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "user_id": event.sender,
+ "token_id": token_id,
+ "txn_id": txn_id,
+ "inserted_ts": self._clock.time_msec(),
+ }
+ )
+
+ if to_insert:
+ self.db_pool.simple_insert_many_txn(
+ txn, table="event_txn_id", values=to_insert,
+ )
+
def _update_current_state_txn(
self,
txn: LoggingTransaction,
@@ -422,12 +457,12 @@ class PersistEventsStore:
# so that async background tasks get told what happened.
sql = """
INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, room_id, type, state_key, null, event_id
+ (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, room_id, type, state_key, null, event_id
FROM current_state_events
WHERE room_id = ?
"""
- txn.execute(sql, (stream_id, room_id))
+ txn.execute(sql, (stream_id, self._instance_name, room_id))
self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id},
@@ -448,8 +483,8 @@ class PersistEventsStore:
#
sql = """
INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
+ (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, ?, (
SELECT event_id FROM current_state_events
WHERE room_id = ? AND type = ? AND state_key = ?
)
@@ -459,6 +494,7 @@ class PersistEventsStore:
(
(
stream_id,
+ self._instance_name,
room_id,
etype,
state_key,
@@ -751,6 +787,7 @@ class PersistEventsStore:
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
+ "instance_name": self._instance_name,
},
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f95679ebc4..3ec4d1d9c2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
import threading
@@ -74,6 +73,13 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
+ # Whether to use dedicated DB threads for event fetching. This is only used
+ # if there are multiple DB threads available. When used, it will lock the DB
+ # thread for periods of time (so unit tests want to disable this when they
+ # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+ # options controlling this.
+ USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -130,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
+ if not hs.config.worker.worker_app:
+ # We periodically clean out old transaction ID mappings
+ self._clock.looping_call(
+ run_as_background_process,
+ 5 * 60 * 1000,
+ "_cleanup_old_transaction_ids",
+ self._cleanup_old_transaction_ids,
+ )
+
self._get_event_cache = Cache(
"*getEvent*",
keylen=3,
@@ -522,7 +537,11 @@ class EventsWorkerStore(SQLBaseStore):
if not event_list:
single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
self._event_fetch_ongoing -= 1
return
else:
@@ -712,6 +731,7 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
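+ # Attach the stream ordering (selected alongside the event JSON above)
+ # to the event's internal metadata.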
+ original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
event_map[event_id] = original_ev
@@ -779,6 +799,8 @@ class EventsWorkerStore(SQLBaseStore):
* event_id (str)
+ * stream_ordering (int): stream ordering for this event
+
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
@@ -811,13 +833,15 @@ class EventsWorkerStore(SQLBaseStore):
sql = """\
SELECT
e.event_id,
- e.internal_metadata,
- e.json,
- e.format_version,
+ e.stream_ordering,
+ ej.internal_metadata,
+ ej.json,
+ ej.format_version,
r.room_version,
rej.reason
- FROM event_json as e
- LEFT JOIN rooms r USING (room_id)
+ FROM events AS e
+ JOIN event_json AS ej USING (event_id)
+ LEFT JOIN rooms r ON r.room_id = e.room_id
LEFT JOIN rejections as rej USING (event_id)
WHERE """
@@ -831,11 +855,12 @@ class EventsWorkerStore(SQLBaseStore):
event_id = row[0]
event_dict[event_id] = {
"event_id": event_id,
- "internal_metadata": row[1],
- "json": row[2],
- "format_version": row[3],
- "room_version_id": row[4],
- "rejected_reason": row[5],
+ "stream_ordering": row[1],
+ "internal_metadata": row[2],
+ "json": row[3],
+ "format_version": row[4],
+ "room_version_id": row[5],
+ "rejected_reason": row[6],
"redactions": [],
}
@@ -1017,16 +1042,12 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
- self, last_id: int, current_id: int, limit: int
+ self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
"""Returns new events, for the Events replication stream
@@ -1050,10 +1071,11 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " AND instance_name = ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
return txn.fetchall()
return await self.db_pool.runInteraction(
@@ -1061,7 +1083,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_ex_outlier_stream_rows(
- self, last_id: int, current_id: int
+ self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
@@ -1080,16 +1102,17 @@ class EventsWorkerStore(SQLBaseStore):
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
+ " INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
+ " AND out.instance_name = ?"
" ORDER BY event_stream_ordering ASC"
)
- txn.execute(sql, (last_id, current_id))
+ txn.execute(sql, (last_id, current_id, instance_name))
return txn.fetchall()
return await self.db_pool.runInteraction(
@@ -1102,6 +1125,9 @@ class EventsWorkerStore(SQLBaseStore):
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
+ NOTE: The IDs given here are from replication, and so should be
+ *positive*.
+
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
@@ -1132,10 +1158,11 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " AND instance_name = ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
- txn.execute(sql, (-last_id, -current_id, limit))
+ txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates = [(row[0], row[1:]) for row in txn]
limited = False
@@ -1149,15 +1176,16 @@ class EventsWorkerStore(SQLBaseStore):
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
+ " INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
+ " AND out.instance_name = ?"
" ORDER BY event_stream_ordering DESC"
)
- txn.execute(sql, (-last_id, -upper_bound))
+ txn.execute(sql, (-last_id, -upper_bound, instance_name))
new_event_updates.extend((row[0], row[1:]) for row in txn)
if len(new_event_updates) >= limit:
@@ -1171,7 +1199,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_all_updated_current_state_deltas(
- self, from_token: int, to_token: int, target_row_count: int
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream
@@ -1197,9 +1225,10 @@ class EventsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, target_row_count))
+ txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
return txn.fetchall()
def get_deltas_for_stream_id_txn(txn, stream_id):
@@ -1287,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
+
+ async def get_event_id_from_transaction_id(
+ self, room_id: str, user_id: str, token_id: int, txn_id: str
+ ) -> Optional[str]:
+ """Look up if we have already persisted an event for the transaction ID,
+ returning the event ID if so.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="event_txn_id",
+ keyvalues={
+ "room_id": room_id,
+ "user_id": user_id,
+ "token_id": token_id,
+ "txn_id": txn_id,
+ },
+ retcol="event_id",
+ allow_none=True,
+ desc="get_event_id_from_transaction_id",
+ )
+
+ async def get_already_persisted_events(
+ self, events: Iterable[EventBase]
+ ) -> Dict[str, str]:
+ """Look up if we have already persisted an event for the transaction ID,
+ returning a mapping from event ID in the given list to the event ID of
+ an existing event.
+
+        Also checks for duplicates within the given events; if any are found,
+        they are mapped to the *first* event.
+ """
+
+ mapping = {}
+ txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
+
+ for event in events:
+ token_id = getattr(event.internal_metadata, "token_id", None)
+ txn_id = getattr(event.internal_metadata, "txn_id", None)
+
+ if token_id and txn_id:
+ # Check if this is a duplicate of an event in the given events.
+ existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
+ if existing:
+ mapping[event.event_id] = existing
+ continue
+
+ # Check if this is a duplicate of an event we've already
+ # persisted.
+ existing = await self.get_event_id_from_transaction_id(
+ event.room_id, event.sender, token_id, txn_id
+ )
+ if existing:
+ mapping[event.event_id] = existing
+ txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
+ else:
+ txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
+
+ return mapping
+
+ async def _cleanup_old_transaction_ids(self):
+ """Cleans out transaction id mappings older than 24hrs.
+ """
+
+ def _cleanup_old_transaction_ids_txn(txn):
+ sql = """
+ DELETE FROM event_txn_id
+ WHERE inserted_ts < ?
+ """
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ txn.execute(sql, (one_day_ago,))
+
+ return await self.db_pool.runInteraction(
+ "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+ )
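
For illustration only (not part of the patch): a minimal sketch of how a send-event path might consult the new lookup before persisting an event. Only get_event_id_from_transaction_id() comes from this change; the wrapper function and the `store` parameter are hypothetical stand-ins.

    from typing import Optional

    async def maybe_reuse_event_id(
        store, room_id: str, user_id: str, token_id: int, txn_id: Optional[str]
    ) -> Optional[str]:
        """Return the previously persisted event ID for this transaction, if any."""
        if txn_id is None:
            # The client supplied no transaction ID, so there is nothing to
            # deduplicate against.
            return None
        # The mapping is scoped to (room, user, access token, txn_id), matching
        # the event_txn_id table added in this change.
        return await store.get_event_id_from_transaction_id(
            room_id, user_id, token_id, txn_id
        )
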
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92099f95ce..0acf0617ca 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,15 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import calendar
+import logging
+import time
+from typing import Dict
from synapse.metrics import GaugeBucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+logger = logging.getLogger(__name__)
+
# Collect metrics on the number of forward extremities that exist.
_extremities_collecter = GaugeBucketCollector(
"synapse_forward_extremities",
@@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
- def read_forward_extremities():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "read_forward_extremities", self._read_forward_extremities
- )
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
- hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+ @wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
def fetch(txn):
txn.execute(
@@ -137,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+ async def count_daily_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 24 hours.
+ """
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return await self.db_pool.runInteraction(
+ "count_daily_users", self._count_users, yesterday
+ )
+
+ async def count_monthly_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phonehome metrics only and is different
+ from the mau figure in synapse.storage.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return await self.db_pool.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
+
+ def _count_users(self, txn, time_from):
+ """
+ Returns number of users seen in the past time_from period
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ (count,) = txn.fetchone()
+ return count
+
+ async def count_r30_users(self) -> Dict[str, int]:
+ """
+ Counts the number of 30 day retained users, defined as:-
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
+ """
+
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == "unknown":
+ pass
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ (count,) = txn.fetchone()
+ results["all"] = count
+
+ return results
+
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = time.gmtime()
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+ return today_start * 1000
+
+ @wrap_as_background_process("generate_user_daily_visits")
+ async def generate_user_daily_visits(self) -> None:
+ """
+        Generates daily visit data for use in cohort/retention analysis
+ """
+
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self._clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
+ self._last_user_visit_update = today_start
+
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather than just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently.
+ self._last_user_visit_update = now
+
+ await self.db_pool.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
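
For illustration only (not part of the patch): the start-of-day and rollover bookkeeping used by generate_user_daily_visits above, shown in isolation. The values are made up.

    import calendar
    import time

    DAY_MS = 24 * 60 * 60 * 1000

    def start_of_utc_day_ms(now_secs: float) -> int:
        # Millisecond unix time for 00:00 UTC of the day containing now_secs,
        # mirroring _get_start_of_day() above.
        t = time.gmtime(now_secs)
        return calendar.timegm((t.tm_year, t.tm_mon, t.tm_mday, 0, 0, 0)) * 1000

    # If the last update predates today's start, the previous day's window still
    # needs a catch-up pass before today's visits are processed.
    last_update = start_of_utc_day_ms(time.time() - 86400)  # pretend we last ran yesterday
    today_start = start_of_utc_day_ms(time.time())
    if today_start > last_update:
        yesterday_start = today_start - DAY_MS
        # ... insert visits for the [last_update, today_start) window first ...
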
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e93aad33cd..d788dc0fc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,6 +15,7 @@
import logging
from typing import Dict, List
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
@@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._max_mau_value = hs.config.max_mau_value
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
desc="user_last_seen_monthly_active",
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_stats_only = hs.config.mau_stats_only
- self._max_mau_value = hs.config.max_mau_value
-
- # Do not add more reserved users than the total allowable number
- # cur = LoggingTransaction(
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
- def _initialise_reserved_users(self, txn, threepids):
- """Ensures that reserved threepids are accounted for in the MAU table, should
- be called on start up.
-
- Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
- """
-
- # XXX what is this function trying to achieve? It upserts into
- # monthly_active_users for each *registered* reserved mau user, but why?
- #
- # - shouldn't there already be an entry for each reserved user (at least
- # if they have been active recently)?
- #
- # - if it's important that the timestamp is kept up to date, why do we only
- # run this at startup?
-
- for tp in threepids:
- user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
- if user_id:
- is_support = self.is_support_user_txn(txn, user_id)
- if not is_support:
- # We do this manually here to avoid hitting #6791
- self.db_pool.simple_upsert_txn(
- txn,
- table="monthly_active_users",
- keyvalues={"user_id": user_id},
- values={"timestamp": int(self._clock.time_msec())},
- )
- else:
- logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
+ @wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
@@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._mau_stats_only = hs.config.mau_stats_only
+
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
+ def _initialise_reserved_users(self, txn, threepids):
+ """Ensures that reserved threepids are accounted for in the MAU table, should
+ be called on start up.
+
+ Args:
+ txn (cursor):
+ threepids (list[dict]): List of threepid dicts to reserve
+ """
+
+ # XXX what is this function trying to achieve? It upserts into
+ # monthly_active_users for each *registered* reserved mau user, but why?
+ #
+ # - shouldn't there already be an entry for each reserved user (at least
+ # if they have been active recently)?
+ #
+ # - if it's important that the timestamp is kept up to date, why do we only
+ # run this at startup?
+
+ for tp in threepids:
+ user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+ if user_id:
+ is_support = self.is_support_user_txn(txn, user_id)
+ if not is_support:
+ # We do this manually here to avoid hitting #6791
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
+ )
+ else:
+ logger.warning("mau limit reserved threepid %s not found in db" % tp)
+
async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..1681caa1f0 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="set_profile_avatar_url",
)
-
-class ProfileStore(ProfileWorkerStore):
- async def add_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
- ) -> None:
- """Ensure we are caching the remote user's profiles.
-
- This should only be called when `is_subscribed_remote_profile_for_user`
- would return true for the user.
- """
- await self.db_pool.simple_upsert(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- values={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="add_remote_profile_cache",
- )
-
async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> int:
@@ -138,6 +117,31 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
+ async def is_subscribed_remote_profile_for_user(self, user_id):
+ """Check whether we are interested in a remote user's profile.
+ """
+ res = await self.db_pool.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ return True
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ return True
+
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
@@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore):
_get_remote_profile_cache_entries_that_expire_txn,
)
- async def is_subscribed_remote_profile_for_user(self, user_id):
- """Check whether we are interested in a remote user's profile.
- """
- res = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
- if res:
- return True
+class ProfileStore(ProfileWorkerStore):
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
+ """Ensure we are caching the remote user's profiles.
- res = await self.db_pool.simple_select_one_onecol(
- table="group_invites",
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ await self.db_pool.simple_upsert(
+ table="remote_profile_cache",
keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
)
-
- if res:
- return True
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..4c843b7679 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
@@ -48,6 +47,18 @@ class RegistrationWorkerStore(SQLBaseStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
+ self._account_validity = hs.config.account_validity
+ if hs.config.run_background_tasks and self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ if hs.config.run_background_tasks:
+ self.clock.looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+ )
+
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
@@ -778,6 +789,105 @@ class RegistrationWorkerStore(SQLBaseStore):
"delete_threepid_session", delete_threepid_session_txn
)
+ @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+ async def cull_expired_threepid_validation_tokens(self) -> None:
+ """Remove threepid validation tokens with expiry dates that have passed"""
+
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self.clock.time_msec(),
+ )
+
+ @wrap_as_background_process("account_validity_set_expiration_dates")
+ async def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db_pool.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ await self.db_pool.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
+
+ async def get_user_pending_deactivation(self) -> Optional[str]:
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
+
+ async def del_user_pending_deactivation(self, user_id: str) -> None:
+ """
+ Removes the given user to the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ await self.db_pool.simple_delete(
+ "users_pending_deactivation",
+ keyvalues={"user_id": user_id},
+ desc="del_user_pending_deactivation",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -911,35 +1021,15 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
-
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
async def add_access_token_to_user(
self,
user_id: str,
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
- ) -> None:
+ ) -> int:
"""Adds an access token for the given user.
Args:
@@ -949,6 +1039,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
+ Returns:
+ The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
@@ -964,6 +1056,38 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
+ return next_id
+
+ def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn, "access_tokens", {"token": token}, "device_id"
+ )
+
+ self.db_pool.simple_update_txn(
+ txn, "access_tokens", {"token": token}, {"device_id": device_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+ return old_device_id
+
+ async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+ """Sets the device ID associated with an access token.
+
+ Args:
+ token: The access token to modify.
+ device_id: The new device ID.
+ Returns:
+ The old device ID associated with the access token.
+ """
+
+ return await self.db_pool.runInteraction(
+ "set_device_for_access_token",
+ self._set_device_for_access_token_txn,
+ token,
+ device_id,
+ )
+
async def register_user(
self,
user_id: str,
@@ -1121,7 +1245,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+ async def user_set_password_hash(
+ self, user_id: str, password_hash: Optional[str]
+ ) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1271,32 +1397,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_user_pending_deactivation",
)
- async def del_user_pending_deactivation(self, user_id: str) -> None:
- """
- Removes the given user to the table of users who need to be parted from all the
- rooms they're in, effectively marking that user as fully deactivated.
- """
- # XXX: This should be simple_delete_one but we failed to put a unique index on
- # the table, so somehow duplicate entries have ended up in it.
- await self.db_pool.simple_delete(
- "users_pending_deactivation",
- keyvalues={"user_id": user_id},
- desc="del_user_pending_deactivation",
- )
-
- async def get_user_pending_deactivation(self) -> Optional[str]:
- """
- Gets one user from the table of users waiting to be parted from all the rooms
- they're in.
- """
- return await self.db_pool.simple_select_one_onecol(
- "users_pending_deactivation",
- keyvalues={},
- retcol="user_id",
- allow_none=True,
- desc="get_users_pending_deactivation",
- )
-
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
@@ -1447,22 +1547,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
- )
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1492,61 +1576,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- async def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- await self.db_pool.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self.db_pool.simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id},
- values={"expiration_ts_ms": expiration_ts, "email_sent": False},
- )
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3c7630857f..e83d961c20 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
"count_public_rooms", _count_public_rooms_txn
)
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return await self.db_pool.runInteraction("get_rooms", f)
+
async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
@@ -857,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
+ async def get_rooms_for_retention_period_in_range(
+ self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+ ) -> Dict[str, dict]:
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+            min_ms: Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms: Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null: Whether to include rooms whose retention policy is NULL
+ in the returned set.
+
+ Returns:
+ The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.db_pool.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.db_pool.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ return await self.db_pool.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1292,18 +1387,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- async def get_room_count(self) -> int:
- """Retrieve the total number of rooms.
- """
-
- def f(txn):
- sql = "SELECT count(*) FROM rooms"
- txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
-
- return await self.db_pool.runInteraction("get_rooms", f)
-
async def add_event_report(
self,
room_id: str,
@@ -1446,88 +1529,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
-
- async def get_rooms_for_retention_period_in_range(
- self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, dict]:
- """Retrieves all of the rooms within the given retention range.
-
- Optionally includes the rooms which don't have a retention policy.
-
- Args:
- min_ms: Duration in milliseconds that define the lower limit of
- the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms: Duration in milliseconds that define the upper limit of
- the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null: Whether to include rooms which retention policy is NULL
- in the returned set.
-
- Returns:
- The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
- """
-
- def get_rooms_for_retention_period_in_range_txn(txn):
- range_conditions = []
- args = []
-
- if min_ms is not None:
- range_conditions.append("max_lifetime > ?")
- args.append(min_ms)
-
- if max_ms is not None:
- range_conditions.append("max_lifetime <= ?")
- args.append(max_ms)
-
- # Do a first query which will retrieve the rooms that have a retention policy
- # in their current state.
- sql = """
- SELECT room_id, min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- """
-
- if len(range_conditions):
- sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
- if include_null:
- sql += " OR max_lifetime IS NULL"
-
- txn.execute(sql, args)
-
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
-
- if include_null:
- # If required, do a second query that retrieves all of the rooms we know
- # of so we can handle rooms with no retention policy.
- sql = "SELECT DISTINCT room_id FROM current_state_events"
-
- txn.execute(sql)
-
- rows = self.db_pool.cursor_to_dict(txn)
-
- # If a room isn't already in the dict (i.e. it doesn't have a retention
- # policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
-
- return rooms_dict
-
- rooms = await self.db_pool.runInteraction(
- "get_rooms_for_retention_period_in_range",
- get_rooms_for_retention_period_in_range_txn,
- )
-
- return rooms
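
For illustration only (not part of the patch): a usage sketch of get_rooms_for_retention_period_in_range, now provided by RoomWorkerStore as moved above. `store` stands for any store object exposing that method.

    from typing import Dict

    async def rooms_to_purge_soon(store) -> Dict[str, dict]:
        # Rooms whose max_lifetime lies in (1 hour, 1 day], plus rooms with no
        # retention policy at all (include_null=True).
        return await store.get_rooms_for_retention_period_in_range(
            min_ms=60 * 60 * 1000,
            max_ms=24 * 60 * 60 * 1000,
            include_null=True,
        )
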
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..20fcdaa529 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,12 +21,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
- LoggingTransaction,
- SQLBaseStore,
- db_to_json,
- make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
@@ -60,15 +55,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# background update still running?
self._current_state_events_membership_up_to_date = False
- txn = LoggingTransaction(
- db_conn.cursor(),
- name="_check_safe_current_state_events_membership_updated",
- database_engine=self.database_engine,
+ txn = db_conn.cursor(
+ txn_name="_check_safe_current_state_events_membership_updated"
)
self._check_safe_current_state_events_membership_updated_txn(txn)
txn.close()
- if self.hs.config.metrics_flags.known_servers:
+ if (
+ self.hs.config.run_background_tasks
+ and self.hs.config.metrics_flags.known_servers
+ ):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
run_as_background_process,
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_token, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_origin_server_ts", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
- database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
- % (",".join("?" for _ in chunk),)
- ),
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),),
[as_id] + chunk,
)
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row = list(row)
row[12] = token_to_stream_ordering(row[12])
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_stream_ordering, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_stream_ordering, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search_order", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_fields_sender_url", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
- database_engine.convert_param_style(
- "UPDATE remote_media_cache SET last_access_ts = ?"
- ),
- (int(time.time() * 1000),),
+ "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
)
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
import logging
+from io import StringIO
from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
logger = logging.getLogger(__name__)
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
select_clause,
)
- if isinstance(database_engine, PostgresEngine):
- cur.execute(sql)
- else:
- cur.executescript(sql)
+ execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' AND state_key LIKE ?
"""
- sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("%:" + config.server_name,))
cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+ user_id TEXT NOT NULL PRIMARY KEY,
+ device_id TEXT NOT NULL,
+ device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+ user_id TEXT NOT NULL, -- The user this fallback key is for.
+ device_id TEXT NOT NULL, -- The device this fallback key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+ CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+ instance_id SERIAL PRIMARY KEY,
+ instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
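
For illustration only (not part of the patch): a get-or-create lookup against the instance_map table defined above, written against a plain DB-API cursor with the Postgres paramstyle. The helper name is hypothetical and races between concurrent workers are ignored for brevity.

    def get_or_create_instance_id(cur, instance_name: str) -> int:
        # The unique index on instance_name keeps the mapping immutable once a
        # name has been assigned an ID.
        cur.execute(
            "SELECT instance_id FROM instance_map WHERE instance_name = %s",
            (instance_name,),
        )
        row = cur.fetchone()
        if row:
            return row[0]
        cur.execute(
            "INSERT INTO instance_map (instance_name) VALUES (%s) RETURNING instance_id",
            (instance_name,),
        )
        return cur.fetchone()[0]
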
diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
new file mode 100644
index 0000000000..b2454121a8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
@@ -0,0 +1,40 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A map of recent events persisted with transaction IDs. Used to deduplicate
+-- send event requests with the same transaction ID.
+--
+-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
+-- used to make the request.
+--
+-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
+-- events or access token we don't want to try and de-duplicate the event.
+CREATE TABLE IF NOT EXISTS event_txn_id (
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ token_id BIGINT NOT NULL,
+ txn_id TEXT NOT NULL,
+ inserted_ts BIGINT NOT NULL,
+ FOREIGN KEY (event_id)
+ REFERENCES events (event_id) ON DELETE CASCADE,
+ FOREIGN KEY (token_id)
+ REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
new file mode 100644
index 0000000000..ad1f481428
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 1d27439536..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
)
+def _filter_results(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ topological_ordering: int,
+ stream_ordering: int,
+) -> bool:
+ """Returns True if the event persisted by the given instance at the given
+ topological/stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+ Used to filter results from fetching events in the DB against the given
+ tokens. This is necessary to handle the case where the tokens include
+ position maps, which we handle by fetching more than necessary from the DB
+ and then filtering (rather than attempting to construct a complicated SQL
+ query).
+ """
+
+ event_historical_tuple = (
+ topological_ordering,
+ stream_ordering,
+ )
+
+ if lower_token:
+ if lower_token.topological is not None:
+ # If these are historical tokens we compare the `(topological, stream)`
+ # tuples.
+ if event_historical_tuple <= lower_token.as_historical_tuple():
+ return False
+
+ else:
+ # If these are live tokens we compare the stream ordering against the
+ # writers stream position.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(
+ instance_name
+ ):
+ return False
+
+ if upper_token:
+ if upper_token.topological is not None:
+ if upper_token.as_historical_tuple() < event_historical_tuple:
+ return False
+ else:
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
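
For illustration only (not part of the patch): the two comparison modes used by _filter_results above, with RoomStreamToken replaced by a minimal stand-in carrying only the fields the filter inspects.

    from typing import Dict, Optional, Tuple

    class SimpleToken:
        def __init__(
            self,
            topological: Optional[int],
            stream: int,
            instance_map: Optional[Dict[str, int]] = None,
        ):
            self.topological = topological
            self.stream = stream
            self.instance_map = instance_map or {}

        def as_historical_tuple(self) -> Tuple[int, int]:
            return (self.topological, self.stream)

        def get_stream_pos_for_instance(self, name: str) -> int:
            # Writers without an explicit entry fall back to the overall position.
            return self.instance_map.get(name, self.stream)

    def after_lower_bound(lower: SimpleToken, instance: str, topo: int, stream: int) -> bool:
        if lower.topological is not None:
            # Historical token: order events by the (topological, stream) tuple.
            return (topo, stream) > lower.as_historical_tuple()
        # Live token: compare against the position recorded for that writer.
        return stream > lower.get_stream_pos_for_instance(instance)
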
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
raise NotImplementedError()
def get_room_max_token(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.get_room_max_stream_ordering())
+ """Get a `RoomStreamToken` that marks the current maximum persisted
+ position of the events stream. Useful to get a token that represents
+ "now".
+
+ The token returned is a "live" token that may have an instance_map
+ component.
+ """
+
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return RoomStreamToken(None, min_pos, positions)
async def get_room_events_stream_for_rooms(
self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
return [], from_key
- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key
def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT event_id, instance_name, topological_ordering, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ order,
+ )
+ txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ][:limit]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
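
The query above intentionally widens the upper bound to `to_key.get_max_stream_pos()` and asks for `2 * limit` rows, then trims in Python. A hedged sketch of that over-fetch-and-trim shape in isolation (`fetch_rows` is a hypothetical query helper, and `keep_row` is the filtering sketch shown earlier):

    def stream_page(room_id, from_key, to_key, limit, fetch_rows):
        # Widen the SQL bounds to the naive maximum stream position so that no
        # writer's events are missed by the coarse range query...
        max_to = max([to_key["stream"], *to_key["instance_map"].values()])
        candidates = fetch_rows(
            room_id,
            min_stream=from_key["stream"],
            max_stream=max_to,
            limit=2 * limit,  # over-fetch, since some rows will be filtered out
        )
        # ...then apply the exact per-writer comparison and cap at the limit.
        return [
            row
            for row in candidates
            if keep_row(
                from_key,
                to_key,
                row["instance_name"],
                row["topological_ordering"],
                row["stream_ordering"],
            )
        ][:limit]
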
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._set_before_and_after(ret, rows, topo_order=False)
if order.lower() == "desc":
ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
return []
- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ ORDER BY e.stream_ordering ASC
+ """
+ txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ]
return rows
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
)
return "t%d-%d" % (topo, token)
- async def get_stream_id_for_event(self, event_id: str) -> int:
- """The stream ID for an event
- Args:
- event_id: The id of the event to look up a stream token for.
- Raises:
- StoreError if the event wasn't in the database.
- Returns:
- A stream ID.
- """
- return await self.db_pool.runInteraction(
- "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
- )
-
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
@@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ if from_token.topological is not None:
+ from_bound = (
+ from_token.as_historical_tuple()
+ ) # type: Tuple[Optional[int], int]
+ elif direction == "b":
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
- args.append(int(limit))
+ # We fetch more events than the limit as the result set is filtered down afterwards
+ args.append(int(limit) * 2)
select_keywords = "SELECT"
join_clause = ""
@@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords += "DISTINCT"
sql = """
- %(select_keywords)s event_id, topological_ordering, stream_ordering
+ %(select_keywords)s
+ event_id, instance_name,
+ topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
txn.execute(sql, args)
- rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+ # Filter the result set.
+ rows = [
+ _EventDictReturn(event_id, topological_ordering, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ lower_token=to_token if direction == "b" else from_token,
+ upper_token=from_token if direction == "b" else to_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ )
+ ][:limit]
if rows:
topo = rows[-1].topological_ordering
@@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return (events, token)
+ @cached()
+ async def get_id_for_instance(self, instance_name: str) -> int:
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+ """
+
+ def _get_id_for_instance_txn(txn):
+ instance_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ allow_none=True,
+ )
+ if instance_id is not None:
+ return instance_id
+
+ # If we don't have an entry, upsert one.
+ #
+ # We could do this before the first check, and rely on the cache for
+ # efficiency, but each UPSERT causes the next ID to increment which
+ # can quickly bloat the size of the generated IDs for new instances.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ values={},
+ )
+
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ )
+
+ return await self.db_pool.runInteraction(
+ "get_id_for_instance", _get_id_for_instance_txn
+ )
+
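
The read-before-upsert dance above avoids bumping the auto-incrementing ID on every call: a plain SELECT satisfies the common case, and only a genuinely new instance name pays for an upsert plus re-read. A generic sketch of the same pattern (`db.select_one` / `db.upsert` are hypothetical helpers, not Synapse's `DatabasePool` API):

    def id_for_instance(db, instance_name: str) -> int:
        existing = db.select_one(
            table="instance_map",
            keyvalues={"instance_name": instance_name},
            retcol="instance_id",
        )
        if existing is not None:
            # Common case: the row already exists and no sequence value is consumed.
            return existing

        # First time this instance name is seen: the upsert may consume a sequence
        # value even when it races another writer, which is why it is not done
        # unconditionally on every call.
        db.upsert(table="instance_map", keyvalues={"instance_name": instance_name}, values={})
        return db.select_one(
            table="instance_map",
            keyvalues={"instance_name": instance_name},
            retcol="instance_id",
        )
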
+ @cached()
+ async def get_name_from_instance_id(self, instance_id: int) -> str:
+ """Get the instance name from an ID previously returned by
+ `get_id_for_instance`.
+ """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="instance_map",
+ keyvalues={"instance_id": instance_id},
+ retcol="instance_name",
+ desc="get_name_from_instance_id",
+ )
+
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 97aed1500e..7d46090267 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple(
SENTINEL = object()
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+ @wrap_as_background_process("cleanup_transactions")
+ async def _cleanup_transactions(self) -> None:
+ now = self._clock.time_msec()
+ month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+ def _cleanup_transactions_txn(txn):
+ txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+ await self.db_pool.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
+
+
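
This rearrangement moves the cleanup job onto the worker store, schedules it only on the instance configured to run background tasks, and lets `@wrap_as_background_process` replace the old `_start_cleanup_transactions` shim. A toy sketch of that scheduling shape (the class and `schedule` callback are stand-ins, not Synapse's `Clock` API):

    class CleanupSketch:
        def __init__(self, run_background_tasks: bool, schedule) -> None:
            # `schedule(callback, interval_ms)` stands in for Clock.looping_call;
            # only the designated background-tasks instance registers the job.
            if run_background_tasks:
                schedule(self.cleanup_transactions, 30 * 60 * 1000)

        def cleanup_transactions(self) -> None:
            month_ago_ms = 30 * 24 * 60 * 60 * 1000
            # The real job deletes received_transactions rows older than a month.
            print("would delete rows with ts <", month_ago_ms, "ms before now")

    # A plain worker that is not the background-tasks instance schedules nothing:
    CleanupSketch(run_background_tasks=False, schedule=lambda cb, ms: None)
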
+class TransactionStore(TransactionWorkerStore):
"""A collection of queries for handling PDUs.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
-
self._destination_retry_cache = ExpiringCache(
cache_name="get_destination_retry_timings",
clock=self._clock,
@@ -266,22 +284,6 @@ class TransactionStore(SQLBaseStore):
},
)
- def _start_cleanup_transactions(self):
- return run_as_background_process(
- "cleanup_transactions", self._cleanup_transactions
- )
-
- async def _cleanup_transactions(self) -> None:
- now = self._clock.time_msec()
- month_ago = now - 30 * 24 * 60 * 60 * 1000
-
- def _cleanup_transactions_txn(txn):
- txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
- await self.db_pool.runInteraction(
- "_cleanup_transactions", _cleanup_transactions_txn
- )
-
async def store_destination_rooms_entries(
self, destinations: Iterable[str], room_id: str, stream_ordering: int,
) -> None:
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
)
return [(row["user_agent"], row["ip"]) for row in rows]
-
-class UIAuthStore(UIAuthWorkerStore):
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
iterable=session_ids,
keyvalues={},
)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ pass