diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index eb1a7e5002..59f3394b0a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -90,8 +90,10 @@ class BackgroundUpdater(object):
self._clock = hs.get_clock()
self.db = database
+ # if a background update is currently running, its name.
+ self._current_background_update = None # type: Optional[str]
+
self._background_update_performance = {}
- self._background_update_queue = []
self._background_update_handlers = {}
self._all_done = False
@@ -111,7 +113,7 @@ class BackgroundUpdater(object):
except Exception:
logger.exception("Error doing update")
else:
- if result is None:
+ if result:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
@@ -119,26 +121,25 @@ class BackgroundUpdater(object):
self._all_done = True
return None
- @defer.inlineCallbacks
- def has_completed_background_updates(self):
+ async def has_completed_background_updates(self) -> bool:
"""Check if all the background updates have completed
Returns:
- Deferred[bool]: True if all background updates have completed
+ True if all background updates have completed
"""
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
return True
- # obviously, if we have things in our queue, we're not done.
- if self._background_update_queue:
+ # obviously, if we are currently processing an update, we're not done.
+ if self._current_background_update:
return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = yield self.db.simple_select_onecol(
+ updates = await self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@@ -153,11 +154,10 @@ class BackgroundUpdater(object):
async def has_completed_background_update(self, update_name) -> bool:
"""Check if the given background update has finished running.
"""
-
if self._all_done:
return True
- if update_name in self._background_update_queue:
+ if update_name == self._current_background_update:
return False
update_exists = await self.db.simple_select_one_onecol(
@@ -170,9 +170,7 @@ class BackgroundUpdater(object):
return not update_exists
- async def do_next_background_update(
- self, desired_duration_ms: float
- ) -> Optional[int]:
+ async def do_next_background_update(self, desired_duration_ms: float) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
@@ -181,33 +179,51 @@ class BackgroundUpdater(object):
desired_duration_ms(float): How long we want to spend
updating.
Returns:
- None if there is no more work to do, otherwise an int
+ True if we have finished running all the background updates, otherwise False
"""
- if not self._background_update_queue:
- updates = await self.db.simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=("update_name", "depends_on"),
+
+ def get_background_updates_txn(txn):
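+ # `ordering` is the column added by delta 58/00background_update_ordering.sql;
+ # updates with a lower ordering run first, with ties broken by update name.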
+ txn.execute(
+ """
+ SELECT update_name, depends_on FROM background_updates
+ ORDER BY ordering, update_name
+ """
)
- in_flight = {update["update_name"] for update in updates}
- for update in updates:
- if update["depends_on"] not in in_flight:
- self._background_update_queue.append(update["update_name"])
+ return self.db.cursor_to_dict(txn)
- if not self._background_update_queue:
- # no work left to do
- return None
+ if not self._current_background_update:
+ all_pending_updates = await self.db.runInteraction(
+ "background_updates", get_background_updates_txn,
+ )
+ if not all_pending_updates:
+ # no work left to do
+ return True
+
+ # find the first update which isn't dependent on another one in the queue.
+ pending = {update["update_name"] for update in all_pending_updates}
+ for upd in all_pending_updates:
+ depends_on = upd["depends_on"]
+ if not depends_on or depends_on not in pending:
+ break
+ logger.info(
+ "Not starting on bg update %s until %s is done",
+ upd["update_name"],
+ depends_on,
+ )
+ else:
+ # if we get to the end of that for loop, there is a problem
+ raise Exception(
+ "Unable to find a background update which doesn't depend on "
+ "another: dependency cycle?"
+ )
- # pop from the front, and add back to the back
- update_name = self._background_update_queue.pop(0)
- self._background_update_queue.append(update_name)
+ self._current_background_update = upd["update_name"]
- res = await self._do_background_update(update_name, desired_duration_ms)
- return res
+ await self._do_background_update(desired_duration_ms)
+ return False
- async def _do_background_update(
- self, update_name: str, desired_duration_ms: float
- ) -> int:
+ async def _do_background_update(self, desired_duration_ms: float) -> int:
+ update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@@ -400,27 +416,6 @@ class BackgroundUpdater(object):
self.register_background_update_handler(update_name, updater)
- def start_background_update(self, update_name, progress):
- """Starts a background update running.
-
- Args:
- update_name: The update to set running.
- progress: The initial state of the progress of the update.
-
- Returns:
- A deferred that completes once the task has been added to the
- queue.
- """
- # Clear the background update queue so that we will pick up the new
- # task on the next iteration of do_background_update.
- self._background_update_queue = []
- progress_json = json.dumps(progress)
-
- return self.db.simple_insert(
- "background_updates",
- {"update_name": update_name, "progress_json": progress_json},
- )
-
def _end_background_update(self, update_name):
"""Removes a completed background update task from the queue.
@@ -429,9 +424,12 @@ class BackgroundUpdater(object):
Returns:
A deferred that completes once the task is removed.
"""
- self._background_update_queue = [
- name for name in self._background_update_queue if name != update_name
- ]
+ if update_name != self._current_background_update:
+ raise Exception(
+ "Cannot end background update %s which isn't currently running"
+ % update_name
+ )
+ self._current_background_update = None
return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
db_conn,
"device_lists_stream",
"stream_id",
- extra_tables=[("user_signature_stream", "stream_id")],
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
},
)
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
-
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index e1ccb27142..92bc06919b 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_tuple_comparison_clause
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
@@ -303,16 +303,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# we'll just end up updating the same device row multiple
# times, which is fine.
- if self.database_engine.supports_tuple_comparison:
- where_clause = "(user_id, device_id) > (?, ?)"
- where_args = [last_user_id, last_device_id]
- else:
- # We explicitly do a `user_id >= ? AND (...)` here to ensure
- # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
- # makes it hard for query optimiser to tell that it can use the
- # index on user_id
- where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
- where_args = [last_user_id, last_user_id, last_device_id]
+ where_clause, where_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [("user_id", last_user_id), ("device_id", last_device_id)],
+ )
sql = """
SELECT
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
+ def get_all_new_device_messages(self, last_pos, current_pos, limit):
+ """
+ Args:
+ last_pos(int):
+ current_pos(int):
+ limit(int):
+ Returns:
+ A deferred list of rows from the device inbox
+ """
+ if last_pos == current_pos:
+ return defer.succeed([])
+
+ def get_all_new_device_messages_txn(txn):
+ # We limit like this as we might have multiple rows per stream_id, and
+ # we want to make sure we always get all entries for any stream_id
+ # we return.
+ upper_pos = min(current_pos, last_pos + limit)
+ sql = (
+ "SELECT max(stream_id), user_id"
+ " FROM device_inbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY user_id"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows = txn.fetchall()
+
+ sql = (
+ "SELECT max(stream_id), destination"
+ " FROM device_federation_outbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY destination"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows.extend(txn)
+
+ # Order by ascending stream ordering
+ rows.sort()
+
+ return rows
+
+ return self.db.runInteraction(
+ "get_all_new_device_messages", get_all_new_device_messages_txn
+ )
+
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
-
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
- Args:
- last_pos(int):
- current_pos(int):
- limit(int):
- Returns:
- A deferred list of rows from the device inbox
- """
- if last_pos == current_pos:
- return defer.succeed([])
-
- def get_all_new_device_messages_txn(txn):
- # We limit like this as we might have multiple rows per stream_id, and
- # we want to make sure we always get all entries for any stream_id
- # we return.
- upper_pos = min(current_pos, last_pos + limit)
- sql = (
- "SELECT max(stream_id), user_id"
- " FROM device_inbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY user_id"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
-
- sql = (
- "SELECT max(stream_id), destination"
- " FROM device_federation_outbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY destination"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
-
- # Order by ascending stream ordering
- rows.sort()
-
- return rows
-
- return self.db.runInteraction(
- "get_all_new_device_messages", get_all_new_device_messages_txn
- )
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 8af5f7de54..ee3a2ab031 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
from six import iteritems
@@ -31,7 +32,11 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import (
+ Database,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -40,6 +45,7 @@ from synapse.util.caches.descriptors import (
cachedList,
)
from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
logger = logging.getLogger(__name__)
@@ -47,6 +53,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
"drop_device_list_streams_non_unique_indexes"
)
+BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -112,23 +120,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- # We retrieve n+1 devices from the list of outbound pokes where n is
- # our outbound device update limit. We then check if the very last
- # device has the same stream_id as the second-to-last device. If so,
- # then we ignore all devices with that stream_id and only send the
- # devices with a lower stream_id.
- #
- # If when culling the list we end up with no devices afterwards, we
- # consider the device update to be too large, and simply skip the
- # stream_id; the rationale being that such a large device list update
- # is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
- limit + 1,
+ limit,
)
# Return an empty list if there are no updates
@@ -166,14 +164,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- # if we have exceeded the limit, we need to exclude any results with the
- # same stream_id as the last row.
- if len(updates) > limit:
- stream_id_cutoff = updates[-1][2]
- now_stream_id = stream_id_cutoff - 1
- else:
- stream_id_cutoff = None
-
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -181,7 +171,6 @@ class DeviceWorkerStore(SQLBaseStore):
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> (stream_id, opentracing_context)
- # as long as their stream_id does not match that of the last row
#
# opentracing_context contains the opentracing metadata for the request
# that created the poke
@@ -192,10 +181,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
- # Stop processing updates
- break
-
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +203,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # That should only happen if a client is spamming the server with new
- # devices, in which case E2E isn't going to work well anyway. We'll just
- # skip that stream_id and return an empty list, and continue with the next
- # stream_id next time.
- if not query_map and not cross_signing_keys_by_user:
- return stream_id_cutoff, []
-
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -259,11 +233,11 @@ class DeviceWorkerStore(SQLBaseStore):
# get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
- WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id
LIMIT ?
"""
- txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
+ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
return list(txn)
@@ -301,7 +275,14 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = yield self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id
)
- for device_id, device in iteritems(user_devices):
+
+ # make sure we go through the devices in stream order
+ device_ids = sorted(
+ user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+ )
+
+ for device_id in device_ids:
+ device = user_devices[device_id]
stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -611,22 +592,33 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int, limit: int,
+ ) -> List[Tuple[int, str]]:
+ """Return a list of `(stream_id, entity)` which is the combined list of
+ changes to devices and which destinations need to be poked. Entity is
+ either a user ID (starting with '@') or a remote destination.
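+
+ An illustrative return value: [(123, '@alice:example.com'), (124, 'example.org')].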
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+ # The query planner correctly applies the stream_id bounds to the
+ # inner queries of the UNION, so this is efficient despite the subselect.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
+ LIMIT ?
"""
- return self.db.execute(
- "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+ return await self.db.execute(
+ "get_all_device_list_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
@cached(max_entries=10000)
@@ -728,6 +720,11 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._drop_device_list_streams_non_unique_indexes,
)
+ # clear out duplicate device list outbound pokes
+ self.db.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -742,6 +739,66 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
return 1
+ async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+ # for some reason, we have accumulated duplicate entries in
+ # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
+ # efficient.
+ #
+ # For each duplicate, we delete all the existing rows and put one back.
+
+ KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
+ last_row = progress.get(
+ "last_row",
+ {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
+ )
+
+ def _txn(txn):
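+ # keyset pagination: resume the scan from the last row processed in the previous batch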
+ clause, args = make_tuple_comparison_clause(
+ self.db.engine, [(x, last_row[x]) for x in KEY_COLS]
+ )
+ sql = """
+ SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
+ FROM device_lists_outbound_pokes
+ WHERE %s
+ GROUP BY %s
+ HAVING count(*) > 1
+ ORDER BY %s
+ LIMIT ?
+ """ % (
+ clause, # WHERE
+ ",".join(KEY_COLS), # GROUP BY
+ ",".join(KEY_COLS), # ORDER BY
+ )
+ txn.execute(sql, args + [batch_size])
+ rows = self.db.cursor_to_dict(txn)
+
+ row = None
+ for row in rows:
+ self.db.simple_delete_txn(
+ txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+ )
+
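+ # the SELECT above doesn't return `sent`, so default the re-inserted row to unsent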
+ row["sent"] = False
+ self.db.simple_insert_txn(
+ txn, "device_lists_outbound_pokes", row,
+ )
+
+ if row:
+ self.db.updates._background_update_progress_txn(
+ txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+ )
+
+ return len(rows)
+
+ rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn)
+
+ if not rows:
+ await self.db.updates._end_background_update(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
+ )
+
+ return rows
+
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
@@ -1021,29 +1078,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ yield self.db.runInteraction(
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
yield self.db.runInteraction(
- "add_device_change_to_streams",
- self._add_device_change_txn,
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
- stream_id,
+ stream_ids,
+ context,
)
- return stream_id
- def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
- now = self._clock.time_msec()
+ return stream_ids[-1]
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ stream_ids: List[int],
+ ):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_id,
- )
+
+ min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
@@ -1052,7 +1129,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1060,11 +1137,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
- for device_id in device_ids
+ for stream_id, device_id in zip(stream_ids, device_ids)
],
)
- context = get_active_span_text_map()
+ def _add_device_outbound_poke_to_stream_txn(
+ self, txn, user_id, device_ids, hosts, stream_ids, context,
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
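+ # each (destination, device_id) row inserted below consumes the next stream id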
self.db.simple_insert_many_txn(
txn,
@@ -1072,7 +1160,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
- "stream_id": stream_id,
+ "stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
@@ -1086,18 +1174,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
- def _prune_old_outbound_device_pokes(self):
+ def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers. We keep one entry per
- (destination, user_id) tuple to ensure that the prev_ids remain correct
- if the server does come back.
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send an m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
"""
- yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ yesterday = self._clock.time_msec() - prune_age
def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
select_sql = """
- SELECT destination, user_id, max(stream_id) as stream_id
- FROM device_lists_outbound_pokes
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
@@ -1108,14 +1225,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not rows:
return
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+ # same stream_id, so we also delete all but one row with the max stream_id.
delete_sql = """
DELETE FROM device_lists_outbound_pokes
- WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
"""
-
- txn.executemany(
- delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
- )
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
# Since we've deleted unsent deltas, we need to remove the entry
# of last successful sent so that the prev_ids are correctly set.
@@ -1125,7 +1257,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""
txn.executemany(sql, ((row[0], row[1]) for row in rows))
- logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+ logger.info("Pruned %d device list outbound pokes", count)
return run_as_background_process(
"prune_old_outbound_device_pokes",
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index c9e7de7d12..e1d1bc3e05 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -14,6 +14,7 @@
# limitations under the License.
from collections import namedtuple
+from typing import Optional
from twisted.internet import defer
@@ -159,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+ def update_aliases_for_room(
+ self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+ ):
+ """Repoint all of the aliases for a given room, to a different room.
+
+ Args:
+ old_room_id: The room whose aliases are to be repointed.
+ new_room_id: The room the aliases should now point to.
+ creator: The user to record as the creator of the new mapping.
+ If None, the creator will be left unchanged.
+ """
+
def _update_aliases_for_room_txn(txn):
- sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
- txn.execute(sql, (new_room_id, creator, old_room_id))
+ update_creator_sql = ""
+ sql_params = (new_room_id, old_room_id)
+ if creator:
+ update_creator_sql = ", creator = ?"
+ sql_params = (new_room_id, creator, old_room_id)
+
+ sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
+ update_creator_sql,
+ )
+ txn.execute(sql, sql_params)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,)
)
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 84594cf0a9..23f4570c4b 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
- "is_verified": row["is_verified"],
+ # is_verified must be returned to the client as a boolean
+ "is_verified": bool(row["is_verified"]),
"session_data": json.loads(row["session_data"]),
}
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 001a53f9b4..bcf746b7ef 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
@@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
- SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id
+ ORDER BY stream_id ASC
+ LIMIT ?
"""
return self.db.execute(
- "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ "get_all_user_signature_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 62d4e9f599..b99439cc37 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -173,19 +173,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
for event_id in initial_events
}
+ # The sorted list of events whose auth chains we should walk.
+ search = [] # type: List[Tuple[int, str]]
+
# We need to get the depth of the initial events for sorting purposes.
sql = """
SELECT depth, event_id FROM events
WHERE %s
- ORDER BY depth ASC
"""
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", initial_events
- )
- txn.execute(sql % (clause,), args)
+ # the list can be huge, so let's avoid looking them all up in one massive
+ # query.
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
- # The sorted list of events whose auth chains we should walk.
- search = txn.fetchall() # type: List[Tuple[int, str]]
+ # I think building a temporary list with fetchall is more efficient than
+ # just `search.extend(txn)`, but this is unconfirmed
+ search.extend(txn.fetchall())
+
+ # sort by depth
+ search.sort()
# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d593ef47b8..e71c23541d 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1267,104 +1267,6 @@ class EventsStore(
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
- def get_current_events_token(self):
- """The current maximum token that events have reached"""
- return self._stream_id_gen.get_current_token()
-
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_forward_event_rows(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < event_stream_ordering"
- " AND event_stream_ordering <= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_id, upper_bound))
- new_event_updates.extend(txn)
-
- return new_event_updates
-
- return self.db.runInteraction(
- "get_all_new_forward_event_rows", get_all_new_forward_event_rows
- )
-
- def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_backfill_event_rows(txn):
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (-last_id, -current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_id, -upper_bound))
- new_event_updates.extend(txn.fetchall())
-
- return new_event_updates
-
- return self.db.runInteraction(
- "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
- )
-
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
@@ -1850,22 +1752,6 @@ class EventsStore(
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
- def get_all_updated_current_state_deltas_txn(txn):
- sql = """
- SELECT stream_id, room_id, type, state_key, event_id
- FROM current_state_delta_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC LIMIT ?
- """
- txn.execute(sql, (from_token, to_token, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_current_state_deltas",
- get_all_updated_current_state_deltas_txn,
- )
-
def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering
):
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ca237c6f12..accde349a7 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -35,7 +35,7 @@ from synapse.api.room_versions import (
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
@@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
- log_ctx = LoggingContext.current_context()
+ log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _get_events_from_db is also responsible for turning db rows
@@ -632,7 +632,7 @@ class EventsWorkerStore(SQLBaseStore):
event_map[event_id] = original_ev
- # finally, we can decide whether each one nededs redacting, and build
+ # finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+ return -self._backfill_id_gen.get_current_token()
+
+ def get_current_events_token(self):
+ """The current maximum token that events have reached"""
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_forward_event_rows(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < event_stream_ordering"
+ " AND event_stream_ordering <= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_id, upper_bound))
+ new_event_updates.extend(txn)
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ )
+
+ def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_backfill_event_rows(txn):
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (-last_id, -current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_id, -upper_bound))
+ new_event_updates.extend(txn.fetchall())
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+ )
+
+ def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
+ )
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 80ca36dedf..8aecd414c2 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -340,7 +340,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_expired_url_cache", _get_expired_url_cache_txn
)
- def delete_url_cache(self, media_ids):
+ async def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
return
@@ -349,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@@ -367,7 +367,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
- def delete_url_cache_media(self, media_ids):
+ async def delete_url_cache_media(self, media_ids):
if len(media_ids) == 0:
return
@@ -380,6 +380,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 604c8b7ddd..dab31e0c2d 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
- for state in presence_states
+ for stream_id, state in zip(stream_orderings, presence_states)
],
)
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
- def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
- sql = (
- "SELECT stream_id, user_id, state, last_active_ts,"
- " last_federation_update_ts, last_user_sync_ts, status_msg,"
- " currently_active"
- " FROM presence_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- )
- txn.execute(sql, (last_id, current_id))
+ sql = """
+ SELECT stream_id, user_id, state, last_active_ts,
+ last_federation_update_ts, last_user_sync_ts,
+ status_msg,
+ currently_active
+ FROM presence_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 62ac88d9f2..b3faafa0a4 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -41,6 +41,7 @@ def _load_rules(rawrules, enabled_map):
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
+ rule["default"] = False
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
@@ -333,6 +334,26 @@ class PushRulesWorkerStore(
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+ """Get all the push rules changes that have happend on the server"""
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
+ )
+
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
@@ -684,26 +705,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_all_push_rule_updates(self, last_id, current_id, limit):
- """Get all the push rules changes that have happend on the server"""
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_push_rule_updates_txn(txn):
- sql = (
- "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
- " op, priority_class, priority, conditions, actions"
- " FROM push_rules_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_push_rule_updates", get_all_push_rule_updates_txn
- )
-
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..aaebe427d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ def get_all_new_public_rooms(txn):
+ sql = """
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (prev_id, current_id, limit))
+ return txn.fetchall()
+
+ if prev_id == current_id:
+ return defer.succeed([])
+
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = """
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """
-
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
-
- if prev_id == current_id:
- return defer.succeed([])
-
- return self.db.runInteraction(
- "get_all_new_public_rooms", get_all_new_public_rooms
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
new file mode 100644
index 0000000000..133d80af35
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- we no longer keep sent outbound device pokes in the db; clear them out
+-- so that we don't have to worry about them.
+--
+-- This is a sequential scan, but it doesn't take too long.
+
+DELETE FROM device_lists_outbound_pokes WHERE sent;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
new file mode 100644
index 0000000000..fdc39e9ba5
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ /* for some reason, we have accumulated duplicate entries in
+ * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
+ * efficient.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json)
+ VALUES (5800, 'remove_dup_outbound_pokes', '{}');
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index ada5cce6c2..e89f0bffb5 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -481,11 +481,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id, limit, end_token
)
- logger.debug("stream before")
events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
- logger.debug("stream after")
self._set_before_and_after(events, rows)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e61595336c..a7cd97b0b0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -17,7 +17,17 @@
import logging
import time
from time import monotonic as monotonic_time
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@@ -32,6 +42,7 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import (
LoggingContext,
LoggingContextOrSentinel,
+ current_context,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -483,7 +494,7 @@ class Database(object):
end = monotonic_time()
duration = end - start
- LoggingContext.current_context().add_database_transaction(duration)
+ current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@@ -510,7 +521,7 @@ class Database(object):
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
- if LoggingContext.current_context() == LoggingContext.sentinel:
+ if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
@@ -547,10 +558,8 @@ class Database(object):
Returns:
Deferred: The result of func
"""
- parent_context = (
- LoggingContext.current_context()
- ) # type: Optional[LoggingContextOrSentinel]
- if parent_context == LoggingContext.sentinel:
+ parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
+ if not parent_context:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
@@ -1558,3 +1567,74 @@ def make_in_list_sql_clause(
return "%s = ANY(?)" % (column,), [list(iterable)]
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+
+
+KV = TypeVar("KV")
+
+
+def make_tuple_comparison_clause(
+ database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
+) -> Tuple[str, List[KV]]:
+ """Returns a tuple comparison SQL clause
+
+ Depending on what the SQL engine supports, builds a SQL clause that looks like either
+ "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
+
+ Args:
+ database_engine
+ keys: An ordered list of (column, value) pairs to be compared.
+
+ Returns:
+ A tuple of SQL query and the args
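+
+ For example (illustrative values; the second form assumes an engine without
+ tuple-comparison support):
+
+ make_tuple_comparison_clause(engine, [("a", 1), ("b", 2)])
+ # -> ("(a,b) > (?,?)", [1, 2])
+ # -> ("(a >= ? AND (a > ? OR b > ?))", [1, 1, 2])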
+ """
+ if database_engine.supports_tuple_comparison:
+ return (
+ "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+ [k[1] for k in keys],
+ )
+
+ # we want to build a clause
+ # (a > ?) OR
+ # (a == ? AND b > ?) OR
+ # (a == ? AND b == ? AND c > ?)
+ # ...
+ # (a == ? AND b == ? AND ... AND z > ?)
+ #
+ # or, equivalently:
+ #
+ # (a > ? OR (a == ? AND
+ # (b > ? OR (b == ? AND
+ # ...
+ # (y > ? OR (y == ? AND
+ # z > ?
+ # ))
+ # ...
+ # ))
+ # ))
+ #
+ # which itself is equivalent to (and apparently easier for the query optimiser):
+ #
+ # (a >= ? AND (a > ? OR
+ # (b >= ? AND (b > ? OR
+ # ...
+ # (y >= ? AND (y > ? OR
+ # z > ?
+ # ))
+ # ...
+ # ))
+ # ))
+ #
+ #
+
+ clause = ""
+ args = [] # type: List[KV]
+ for k, v in keys[:-1]:
+ clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
+ args.extend([v, v])
+
+ (k, v) = keys[-1]
+ clause += "%s > ?" % (k,)
+ args.append(v)
+
+ clause += "))" * (len(keys) - 1)
+ return clause, args
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6cb7d4b922..1712932f31 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 57
+SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/schema/delta/58/00background_update_ordering.sql b/synapse/storage/schema/delta/58/00background_update_ordering.sql
new file mode 100644
index 0000000000..02dae587cc
--- /dev/null
+++ b/synapse/storage/schema/delta/58/00background_update_ordering.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* add an "ordering" column to background_updates, which can be used to sort them
+ so that they are run in a consistent, predictable order. */
+
+ALTER TABLE background_updates ADD COLUMN ordering INT NOT NULL DEFAULT 0;
|