diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a19d65ad23..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -262,13 +262,18 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+ """Similar to `executemany`, except `txn.rowcount` will not be correct
+ afterwards.
+
+ More efficient than `executemany` on PostgreSQL
+ """
+
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
- for val in args:
- self.execute(sql, val)
+ self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
@@ -888,7 +893,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
- txn.executemany(sql, vals)
+ txn.execute_batch(sql, vals)
async def simple_upsert(
self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
+ EventForwardExtremitiesStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 68896f34af..a277a1ef13 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -68,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore):
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
- if hs.get_instance_name() in hs.config.worker.writers.events:
+ if hs.get_instance_name() in hs.config.worker.writers.account_data:
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
+ txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about
# when the latest change happened.
- txn.executemany(
+ txn.execute_batch(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ ) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database.
Args:
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
_gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
- txn.executemany(
+ txn.execute_batch(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
- @staticmethod
+ @classmethod
def _add_chain_cover_index(
+ cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
@@ -614,60 +615,17 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for:
return
- # We now calculate the chain IDs/sequence numbers for the events. We
- # do this by looking at the chain ID and sequence number of any auth
- # event with the same type/state_key and incrementing the sequence
- # number by one. If there was no match or the chain ID/sequence
- # number is already taken we generate a new chain.
- #
- # We need to do this in a topologically sorted order as we want to
- # generate chain IDs/sequence numbers of an event's auth events
- # before the event itself.
- chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
- new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
- for event_id in sorted_topologically(
- events_to_calc_chain_id_for, event_to_auth_chain
- ):
- existing_chain_id = None
- for auth_id in event_to_auth_chain.get(event_id, []):
- if event_to_types.get(event_id) == event_to_types.get(auth_id):
- existing_chain_id = chain_map[auth_id]
- break
-
- new_chain_tuple = None
- if existing_chain_id:
- # We found a chain ID/sequence number candidate, check its
- # not already taken.
- proposed_new_id = existing_chain_id[0]
- proposed_new_seq = existing_chain_id[1] + 1
- if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
- already_allocated = db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_auth_chains",
- keyvalues={
- "chain_id": proposed_new_id,
- "sequence_number": proposed_new_seq,
- },
- retcol="event_id",
- allow_none=True,
- )
- if already_allocated:
- # Mark it as already allocated so we don't need to hit
- # the DB again.
- chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
- else:
- new_chain_tuple = (
- proposed_new_id,
- proposed_new_seq,
- )
-
- if not new_chain_tuple:
- new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
- chains_tuples_allocated.add(new_chain_tuple)
-
- chain_map[event_id] = new_chain_tuple
- new_chain_tuples[event_id] = new_chain_tuple
+ # Allocate chain ID/sequence numbers to each new event.
+ new_chain_tuples = cls._allocate_chain_ids(
+ txn,
+ db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ events_to_calc_chain_id_for,
+ chain_map,
+ )
+ chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn(
txn,
@@ -794,6 +752,137 @@ class PersistEventsStore:
],
)
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`, returning a map from event ID
+        to (chain ID, sequence number). (c.f. _add_chain_cover_index for args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:  # for/else: only reached when the loop above never hit `break`
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate the new events' chain ID/sequence numbers.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the newly generated `new_chain_tuples` map or, failing
+            # that, from the existing `chain_map`.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check it's
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
+
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
@@ -876,7 +965,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -895,7 +984,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@@ -907,7 +996,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -927,7 +1016,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@@ -938,7 +1027,7 @@ class PersistEventsStore:
)
if to_insert:
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1738,7 +1827,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -1767,7 +1856,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@@ -1886,7 +1975,7 @@ class PersistEventsStore:
" )"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1900,7 +1989,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(ev.event_id, ev.room_id)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, update_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, rows_to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room, keeping only the
+        one with the highest stream ordering.
+
+        Invalidates the "get_latest_event_ids_in_room" cache if any were
+        deleted. Raises a 400 SynapseError if the room has none to keep.
+
+        Returns the number of forward extremities deleted.
+        """
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = """
+                SELECT event_id FROM event_forward_extremities
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+                ORDER BY stream_ordering DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+            if not rows:
+                # Indexing an empty list raises IndexError, so check up front.
+                msg = "No forward extremity event found for room %s" % room_id
+                logger.warning(msg)
+                raise SynapseError(400, msg)
+            event_id = rows[0][0]
+            logger.debug(
+                "Found event_id %s as the forward extremity to keep for room %s",
+                event_id,
+                room_id,
+            )
+
+            # Now delete the extra forward extremities
+            sql = """
+                DELETE FROM event_forward_extremities
+                WHERE event_id != ? AND room_id = ?
+            """
+            txn.execute(sql, (event_id, room_id))
+            # Capture now: any later statement run on `txn` would overwrite `rowcount`.
+            deleted_count = txn.rowcount
+            logger.info(
+                "Deleted %s extra forward extremities for room %s",
+                deleted_count,
+                room_id,
+            )
+
+            if deleted_count > 0:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_latest_event_ids_in_room, (room_id,),
+                )
+
+            return deleted_count
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room",
+            delete_forward_extremities_for_room_txn,
+        )
+
+    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get the forward extremities of a room, one dict per extremity."""
+
+        def get_forward_extremities_for_room_txn(txn):
+            sql = """
+                SELECT event_id, state_group, depth, received_ts
+                FROM event_forward_extremities
+                INNER JOIN event_to_state_groups USING (event_id)
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+            """
+
+            txn.execute(sql, (room_id,))
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+        )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
sql,
(
(time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
- txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+ txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
+    async def count_daily_e2ee_messages(self):
+        """
+        Returns the number of `m.room.encrypted` events with a stream ordering
+        greater than `self.stream_ordering_day_ago`.
+
+        NOTE(review): that attribute presumably marks the position one day ago.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+    async def count_daily_sent_e2ee_messages(self):
+        """Count local users' encrypted events after `stream_ordering_day_ago`."""
+
+        def _count_sent_txn(txn):
+            # Good enough: a hostname containing LIKE wildcards would over-match,
+            # but an operator with such a hostname only has themselves to blame.
+            sender_pattern = "%:" + self.hs.hostname
+            query = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+            txn.execute(query, (sender_pattern, self.stream_ordering_day_ago))
+            (total,) = txn.fetchone()
+            return total
+
+        return await self.db_pool.runInteraction(
+            "count_daily_sent_e2ee_messages", _count_sent_txn
+        )
+
+    async def count_daily_active_e2ee_rooms(self):
+        """Count rooms with encrypted events after `stream_ordering_day_ago`."""
+        def _count_rooms_txn(txn):
+            query = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(query, (self.stream_ordering_day_ago,))
+            (room_total,) = txn.fetchone()
+            return room_total
+        return await self.db_pool.runInteraction(
+            "count_daily_active_e2ee_rooms", _count_rooms_txn
+        )
+
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremeties
- txn.executemany(
+ txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db_pool.simple_delete_one_txn(
+ # It is expected that there is exactly one pusher to delete, but
+ # if it isn't there (or there are multiple) delete them all.
+ self.db_pool.simple_delete_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e0e57f0578..e4843a202c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -45,7 +45,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
- stream_name="account_data",
+ stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
sequence_name="receipts_sequence",
@@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
- if hs.get_instance_name() in hs.config.worker.writers.events:
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8d05288ed4..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user is shadow-banned.
+
+        Args:
+            user: user ID of the user to update
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user.to_string()},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user.to_string()},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@@ -443,6 +472,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+    async def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> None:
+        """Record a row in `user_external_ids` mapping an external user id to a mxid.
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        await self.db_pool.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@@ -1104,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
+ txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
@@ -1371,26 +1420,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- async def record_user_external_id(
- self, auth_provider: str, external_id: str, user_id: str
- ) -> None:
- """Record a mapping from an external user id to a mxid
-
- Args:
- auth_provider: identifier for the remote auth provider
- external_id: id on that system
- user_id: complete mxid that it is mapped to
- """
- await self.db_pool.simple_insert(
- table="user_external_ids",
- values={
- "auth_provider": auth_provider,
- "external_id": external_id,
- "user_id": user_id,
- },
- desc="record_user_external_id",
- )
-
async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 284f2ce77c..a9fcb5f59c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -16,7 +16,6 @@
import collections
import logging
-import re
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
@@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import MXC_REGEX
logger = logging.getLogger(__name__)
@@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
sql = """
SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id)
@@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
for url in (content_url, thumbnail_url):
if not url:
continue
- matches = mxc_re.match(url)
+ matches = MXC_REGEX.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
- INSERT_CLUMP_SIZE = 1000
-
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
+ txn.execute_batch(to_update_sql, to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
- cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+ cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
new file mode 100644
index 0000000000..9f2b5ebc5a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We incorrectly populated these stream positions, so we delete them and
+-- let the MultiWriterIdGenerator repopulate them.
+DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data';
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
async def search_rooms(
self,
- room_ids: List[str],
+ room_ids: Collection[str],
search_term: str,
keys: List[str],
limit,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..d421d18f8d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
-from collections import Counter
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
+from typing_extensions import Counter
+
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+ async def get_earliest_token_for_stats(
+ self, stats_type: str, id: str
+ ) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
)
async def bulk_update_stats_delta(
- self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
- ) -> Dict[str, Dict[str, int]]:
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
max_pos,
)
- def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+ def get_changes_room_total_events_and_bytes_txn(
+ self, txn, low_pos: int, high_pos: int
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Gets the total_events and total_event_bytes counts for rooms and
senders, in a range of stream_orderings (including backfilled events).
Args:
txn
- low_pos (int): Low stream ordering
- high_pos (int): High stream ordering
+ low_pos: Low stream ordering
+ high_pos: High stream ordering
Returns:
- tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
- room and user deltas for total_events/total_event_bytes in the
+ The room and user deltas for total_events/total_event_bytes in the
format of `stats_id` -> fields
"""
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..7b9729da09 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: int) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 39a3ab1162..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion order ensures that the IDs are kept in the correct order.
+ #
+ # The keys and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
@@ -261,7 +266,11 @@ class MultiWriterIdGenerator:
# We check that the table and sequence haven't diverged.
for table, _, id_column in tables:
self._sequence_gen.check_consistency(
- db_conn, table=table, id_column=id_column, positive=positive
+ db_conn,
+ table=table,
+ id_column=id_column,
+ stream_name=stream_name,
+ positive=positive,
)
# This goes and fills out the above state from the database.
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 412df6b8ef..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -45,6 +45,21 @@ and run the following SQL:
See docs/postgres.md for more information.
"""
+_INCONSISTENT_STREAM_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated stream position
+of '%(stream_name)s' in the 'stream_positions' table.
+
+This is likely a programming error and should be reported at
+https://github.com/matrix-org/synapse.
+
+A temporary workaround to fix this error is to shut down Synapse (including
+any and all workers) and run the following SQL:
+
+ DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
+
+This will need to be done every time the server is restarted.
+"""
+
class SequenceGenerator(metaclass=abc.ABCMeta):
"""A class which generates a unique sequence of integers"""
@@ -55,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
- This is to handle various cases where the sequence value can get out
- of sync with the table, e.g. if Synapse gets rolled back to a previous
+ This is to handle various cases where the sequence value can get out of
+ sync with the table, e.g. if Synapse gets rolled back to a previous
version and the rolled forwards again.
+
+ If a stream name is given then this will check that any value in the
+ `stream_positions` table is less than or equal to the current sequence
+ value. If it isn't then it's likely that streams have been crossed
+ somewhere (e.g. two ID generators have the same stream name).
"""
...
@@ -93,8 +119,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
+ """See SequenceGenerator.check_consistency for docstring.
+ """
+
txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
@@ -118,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator):
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
last_value, is_called = txn.fetchone()
+
+ # If we have an associated stream check the stream_positions table.
+ max_in_stream_positions = None
+ if stream_name:
+ txn.execute(
+ "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
+ (stream_name,),
+ )
+ row = txn.fetchone()
+ if row:
+ max_in_stream_positions = row[0]
+
txn.close()
# If `is_called` is False then `last_value` is actually the value that
@@ -138,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
% {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
)
+ # If we have values in the stream positions table then they have to be
+ # less than or equal to `last_value`
+ if max_in_stream_positions and max_in_stream_positions > last_value:
+ raise IncorrectDatabaseSetup(
+ _INCONSISTENT_STREAM_ERROR
+ % {"seq": self._sequence_name, "stream_name": stream_name}
+ )
+
GetFirstCallbackType = Callable[[Cursor], int]
@@ -174,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: Connection,
+ table: str,
+ id_column: str,
+ stream_name: Optional[str] = None,
+ positive: bool = True,
):
# There is nothing to do for in memory sequences
pass
|