diff options
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 148 |
1 files changed, 109 insertions, 39 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b2645ab43c..56e8eb16a8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -28,6 +28,8 @@ from typing import ( cast, ) +from immutabledict import immutabledict + from synapse.api.constants import EduTypes from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, JsonMapping +from synapse.types import ( + JsonDict, + JsonMapping, + MultiWriterStreamToken, + PersistedPosition, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "receipts_linearized", entity_column="room_id", stream_column="stream_id", - max_value=max_receipts_stream_id, + max_value=max_receipts_stream_id.stream, limit=10000, ) self._receipts_stream_cache = StreamChangeCache( @@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - def get_max_receipt_stream_id(self) -> int: + def get_max_receipt_stream_id(self) -> MultiWriterStreamToken: """Get the current max stream ID for receipts stream""" - return self._receipts_id_gen.get_current_token() + + min_pos = self._receipts_id_gen.get_current_token() + + positions = {} + if isinstance(self._receipts_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._receipts_id_gen.get_positions().items() + if p > min_pos + } + + return MultiWriterStreamToken( + stream=min_pos, instance_map=immutabledict(positions) + ) + + def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: + return self._receipts_id_gen.get_current_token_for_writer(instance_name) def get_last_unthreaded_receipt_for_user_txn( self, @@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore): } async def get_linearized_receipts_for_rooms( - self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None + self, + room_ids: Iterable[str], + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> List[JsonMapping]: """Get receipts for multiple rooms for sending to clients. @@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # Only ask the database about rooms where there have been new # receipts added since `from_key` room_ids = self._receipts_stream_cache.get_entities_changed( - room_ids, from_key + room_ids, from_key.stream ) results = await self._get_linearized_receipts_for_rooms( @@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return [ev for res in results.values() for ev in res] async def get_linearized_receipts_for_room( - self, room_id: str, to_key: int, from_key: Optional[int] = None + self, + room_id: str, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. @@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore): if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. - if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + if not self._receipts_stream_cache.has_entity_changed( + room_id, from_key.stream + ): return [] return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) @cached(tree=True) async def _get_linearized_receipts_for_room( - self, room_id: str, to_key: int, from_key: Optional[int] = None + self, + room_id: str, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: if from_key: - sql = ( - "SELECT receipt_type, user_id, event_id, data" - " FROM receipts_linearized WHERE" - " room_id = ? AND stream_id > ? AND stream_id <= ?" - ) + sql = """ + SELECT stream_id, instance_name, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE room_id = ? AND stream_id > ? AND stream_id <= ? + """ - txn.execute(sql, (room_id, from_key, to_key)) - else: - sql = ( - "SELECT receipt_type, user_id, event_id, data" - " FROM receipts_linearized WHERE" - " room_id = ? AND stream_id <= ?" + txn.execute( + sql, (room_id, from_key.stream, to_key.get_max_stream_pos()) ) + else: + sql = """ + SELECT stream_id, instance_name, receipt_type, user_id, event_id, data + FROM receipts_linearized WHERE + room_id = ? AND stream_id <= ? + """ - txn.execute(sql, (room_id, to_key)) + txn.execute(sql, (room_id, to_key.get_max_stream_pos())) - return cast(List[Tuple[str, str, str, str]], txn.fetchall()) + return [ + (receipt_type, user_id, event_id, data) + for stream_id, instance_name, receipt_type, user_id, event_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) @@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore): num_args=3, ) async def _get_linearized_receipts_for_rooms( - self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + self, + room_ids: Collection[str], + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} @@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> List[Tuple[str, str, str, str, Optional[str], str]]: if from_key: sql = """ - SELECT room_id, receipt_type, user_id, event_id, thread_id, data + SELECT stream_id, instance_name, room_id, receipt_type, + user_id, event_id, thread_id, data FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? AND """ @@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore): self.database_engine, "room_id", room_ids ) - txn.execute(sql + clause, [from_key, to_key] + list(args)) + txn.execute( + sql + clause, + [from_key.stream, to_key.get_max_stream_pos()] + list(args), + ) else: sql = """ - SELECT room_id, receipt_type, user_id, event_id, thread_id, data + SELECT stream_id, instance_name, room_id, receipt_type, + user_id, event_id, thread_id, data FROM receipts_linearized WHERE stream_id <= ? AND """ @@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore): self.database_engine, "room_id", room_ids ) - txn.execute(sql + clause, [to_key] + list(args)) + txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) - return cast( - List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall() - ) + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f @@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore): num_args=2, ) async def get_linearized_receipts_for_all_rooms( - self, to_key: int, from_key: Optional[int] = None + self, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. @@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore): def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: if from_key: sql = """ - SELECT room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ - txn.execute(sql, [from_key, to_key]) + txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()]) else: sql = """ - SELECT room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ - txn.execute(sql, [to_key]) + txn.execute(sql, [to_key.get_max_stream_pos()]) - return cast(List[Tuple[str, str, str, str, str]], txn.fetchall()) + return [ + (room_id, receipt_type, user_id, event_id, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] txn_results = await self.db_pool.runInteraction( "get_linearized_receipts_for_all_rooms", f @@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore): SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? + AND instance_name = ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (last_id, current_id, limit)) + txn.execute(sql, (last_id, current_id, instance_name, limit)) updates = cast( List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], @@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore): keyvalues=keyvalues, values={ "stream_id": stream_id, + "instance_name": self._instance_name, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), @@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_ids: List[str], thread_id: Optional[str], data: dict, - ) -> Optional[int]: + ) -> Optional[PersistedPosition]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph @@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - return stream_id + return PersistedPosition(self._instance_name, stream_id) async def _insert_graph_receipt( self, |