diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 5b31aab700..97aed1500e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,13 +15,14 @@
import logging
from collections import namedtuple
-from typing import Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -47,7 +48,7 @@ class TransactionStore(SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TransactionStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore):
allow_none=True,
)
- if result and result["retry_last_ts"] > 0:
+ # check we have a row and retry_last_ts is not null or zero
+ # (retry_last_ts can't be negative)
+ if result and result["retry_last_ts"]:
return result
else:
return None
@@ -215,6 +218,7 @@ class TransactionStore(SQLBaseStore):
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
@@ -246,7 +250,11 @@ class TransactionStore(SQLBaseStore):
"retry_interval": retry_interval,
},
)
- elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+ elif (
+ retry_interval == 0
+ or prev_row["retry_interval"] is None
+ or prev_row["retry_interval"] < retry_interval
+ ):
self.db_pool.simple_update_one_txn(
txn,
"destinations",
@@ -273,3 +281,196 @@ class TransactionStore(SQLBaseStore):
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
+
+ async def store_destination_rooms_entries(
+ self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ ) -> None:
+ """
+ Updates or creates `destination_rooms` entries in batch for a single event.
+
+ Args:
+ destinations: list of destinations
+ room_id: the room_id of the event
+ stream_ordering: the stream_ordering of the event
+ """
+
+ return await self.db_pool.runInteraction(
+ "store_destination_rooms_entries",
+ self._store_destination_rooms_entries_txn,
+ destinations,
+ room_id,
+ stream_ordering,
+ )
+
+ def _store_destination_rooms_entries_txn(
+ self,
+ txn: LoggingTransaction,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
+ ) -> None:
+
+ # ensure we have a `destinations` row for this destination, as there is
+ # a foreign key constraint.
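+        # (the "insert if not exists" syntax differs between the engines:
+        # SQLite only gained ON CONFLICT support in 3.24, so INSERT OR
+        # IGNORE is used there instead.)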
+ if isinstance(self.database_engine, PostgresEngine):
+ q = """
+ INSERT INTO destinations (destination)
+ VALUES (?)
+ ON CONFLICT DO NOTHING;
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ q = """
+ INSERT OR IGNORE INTO destinations (destination)
+ VALUES (?);
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+        # materialise the Iterable: it is iterated twice (once here, once
+        # when building `rows` below)
+        destinations = list(destinations)
+
+        txn.execute_batch(q, ((destination,) for destination in destinations))
+
+ rows = [(destination, room_id) for destination in destinations]
+
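+        # upsert a stream_ordering for each (destination, room_id) pair:
+        # `rows` supplies the key columns, and every pair gets the same
+        # stream_ordering as its value column.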
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ "destination_rooms",
+ ["destination", "room_id"],
+ rows,
+ ["stream_ordering"],
+ [(stream_ordering,)] * len(rows),
+ )
+
+ async def get_destination_last_successful_stream_ordering(
+ self, destination: str
+ ) -> Optional[int]:
+ """
+ Gets the stream ordering of the PDU most-recently successfully sent
+ to the specified destination, or None if this information has not been
+ tracked yet.
+
+ Args:
+ destination: the destination to query
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "destinations",
+ {"destination": destination},
+ "last_successful_stream_ordering",
+ allow_none=True,
+ desc="get_last_successful_stream_ordering",
+ )
+
+ async def set_destination_last_successful_stream_ordering(
+ self, destination: str, last_successful_stream_ordering: int
+ ) -> None:
+ """
+ Marks that we have successfully sent the PDUs up to and including the
+ one specified.
+
+ Args:
+ destination: the destination we have successfully sent to
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ return await self.db_pool.simple_upsert(
+ "destinations",
+ keyvalues={"destination": destination},
+ values={"last_successful_stream_ordering": last_successful_stream_ordering},
+ desc="set_last_successful_stream_ordering",
+ )
+
+ async def get_catch_up_room_event_ids(
+ self, destination: str, last_successful_stream_ordering: int,
+ ) -> List[str]:
+ """
+        Returns at most 50 event IDs, in ascending stream_ordering order, for
+        the oldest events that have not yet been sent to the destination.
+
+ Args:
+ destination: the destination in question
+ last_successful_stream_ordering: the stream_ordering of the
+ most-recently successfully-transmitted event to the destination
+
+ Returns:
+ list of event_ids
+ """
+ return await self.db_pool.runInteraction(
+ "get_catch_up_room_event_ids",
+ self._get_catch_up_room_event_ids_txn,
+ destination,
+ last_successful_stream_ordering,
+ )
+
+ @staticmethod
+ def _get_catch_up_room_event_ids_txn(
+ txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+ ) -> List[str]:
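+        # destination_rooms and events share a stream_ordering column, so
+        # joining on it maps each pending row to its event ID; this returns
+        # the oldest 50 unsent events for this destination, in order.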
+ q = """
+ SELECT event_id FROM destination_rooms
+ JOIN events USING (stream_ordering)
+ WHERE destination = ?
+ AND stream_ordering > ?
+ ORDER BY stream_ordering
+ LIMIT 50
+ """
+ txn.execute(
+ q, (destination, last_successful_stream_ordering),
+ )
+ event_ids = [row[0] for row in txn]
+ return event_ids
+
+ async def get_catch_up_outstanding_destinations(
+ self, after_destination: Optional[str]
+ ) -> List[str]:
+ """
+        Gets at most 25 destinations which have outstanding PDUs to be caught
+        up, and which are not currently being backed off from.
+
+        Args:
+ after_destination:
+ If provided, all destinations must be lexicographically greater
+ than this one.
+
+ Returns:
+ list of up to 25 destinations with outstanding catch-up.
+ These are the lexicographically first destinations which are
+ lexicographically greater than after_destination (if provided).
+ """
+ time = self.hs.get_clock().time_msec()
+
+ return await self.db_pool.runInteraction(
+ "get_catch_up_outstanding_destinations",
+ self._get_catch_up_outstanding_destinations_txn,
+ time,
+ after_destination,
+ )
+
+ @staticmethod
+ def _get_catch_up_outstanding_destinations_txn(
+ txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
+ ) -> List[str]:
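+        # the correlated subquery selects destinations with at least one
+        # destination_rooms row newer than the destination's high-water mark
+        # (a NULL last_successful_stream_ordering never compares greater, so
+        # untracked destinations are not returned); the remaining conditions
+        # paginate lexicographically and skip destinations in backoff.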
+ q = """
+ SELECT destination FROM destinations
+ WHERE destination IN (
+ SELECT destination FROM destination_rooms
+ WHERE destination_rooms.stream_ordering >
+ destinations.last_successful_stream_ordering
+ )
+ AND destination > ?
+ AND (
+ retry_last_ts IS NULL OR
+ retry_last_ts + retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
+ """
+ txn.execute(
+ q,
+ (
+ # everything is lexicographically greater than "" so this gives
+ # us the first batch of up to 25.
+ after_destination or "",
+ now_time_ms,
+ ),
+ )
+
+ destinations = [row[0] for row in txn]
+ return destinations
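
A rough usage sketch (not part of the diff above): a caller such as the
federation sender could drive catch-up with these queries along the
following lines, where `send_pdus` is a hypothetical transmission helper:

    async def catch_up_all(store) -> None:
        last_seen = None
        while True:
            # walk the destinations with outstanding catch-up, 25 at a time,
            # paginating lexicographically by destination name
            destinations = await store.get_catch_up_outstanding_destinations(
                last_seen
            )
            if not destinations:
                break
            for destination in destinations:
                last_pos = await store.get_destination_last_successful_stream_ordering(
                    destination
                )
                # fetch up to 50 of the oldest unsent event IDs...
                event_ids = await store.get_catch_up_room_event_ids(
                    destination, last_pos or 0
                )
                # ...re-send them (hypothetical helper), then advance the
                # high-water mark with
                # set_destination_last_successful_stream_ordering
                await send_pdus(destination, event_ids)
            last_seen = destinations[-1]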