Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--  synapse/storage/databases/main/receipts.py | 196
1 file changed, 136 insertions(+), 60 deletions(-)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py

index 0231f9407b..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
     cast,
 )
 
+from immutabledict import immutabledict
+
 from synapse.api.constants import EduTypes
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    MultiWriterStreamToken,
+    PersistedPosition,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "receipts_linearized",
             entity_column="room_id",
             stream_column="stream_id",
-            max_value=max_receipts_stream_id,
+            max_value=max_receipts_stream_id.stream,
             limit=10000,
         )
         self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
             prefilled_cache=receipts_stream_prefill,
         )
 
-    def get_max_receipt_stream_id(self) -> int:
+    def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
         """Get the current max stream ID for receipts stream"""
-        return self._receipts_id_gen.get_current_token()
+
+        min_pos = self._receipts_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._receipts_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return MultiWriterStreamToken(
+            stream=min_pos, instance_map=immutabledict(positions)
+        )
+
+    def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+        return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 
     def get_last_unthreaded_receipt_for_user_txn(
         self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
 
     async def get_linearized_receipts_for_rooms(
-        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Iterable[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> List[JsonMapping]:
         """Get receipts for multiple rooms for sending to clients.
 
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
             room_ids = self._receipts_stream_cache.get_entities_changed(
-                room_ids, from_key
+                room_ids, from_key.stream
             )
 
         results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return [ev for res in results.values() for ev in res]
 
     async def get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """Get receipts for a single room for sending to clients.
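The `get_max_receipt_stream_id` hunk above builds the new token from the safe minimum position plus a map containing only the writers that are ahead of it. A minimal, standalone sketch of that construction, with made-up instance positions (the `build_token` helper is hypothetical, not Synapse code):

from typing import Dict, Mapping, Tuple

from immutabledict import immutabledict


def build_token(min_pos: int, positions: Mapping[str, int]) -> Tuple[int, Mapping[str, int]]:
    # Writers at or below min_pos are already covered by the `stream` field,
    # so only writers strictly ahead of it need an explicit entry.
    ahead: Dict[str, int] = {i: p for i, p in positions.items() if p > min_pos}
    return (min_pos, immutabledict(ahead))


# worker1 has persisted up to 7, worker2 only up to 5; positions <= 6 are
# known to be fully persisted everywhere, so only worker1 appears in the map.
assert build_token(6, {"worker1": 7, "worker2": 5}) == (6, immutabledict({"worker1": 7}))

Keeping the map to only the writers ahead of the minimum keeps tokens compact while still recording exactly how far each writer had persisted.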
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since `from_key`. If not we can no-op.
-            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+            if not self._receipts_stream_cache.has_entity_changed(
+                room_id, from_key.stream
+            ):
                 return []
 
         return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """See get_linearized_receipts_for_room"""
 
-        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
             if from_key:
-                sql = (
-                    "SELECT * FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
-                )
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized
+                    WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, from_key, to_key))
-            else:
-                sql = (
-                    "SELECT * FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id <= ?"
+                txn.execute(
+                    sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
                 )
+            else:
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
+                    room_id = ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, to_key))
-
-            rows = self.db_pool.cursor_to_dict(txn)
+                txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
 
-            return rows
+            return [
+                (receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
@@ -339,10 +387,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             return []
 
         content: JsonDict = {}
-        for row in rows:
-            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
-                row["user_id"]
-            ] = db_to_json(row["data"])
+        for receipt_type, user_id, event_id, data in rows:
+            content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+                user_id
+            ] = db_to_json(data)
 
         return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
 
@@ -352,25 +400,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=3,
     )
     async def _get_linearized_receipts_for_rooms(
-        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Collection[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, Sequence[JsonMapping]]:
         if not room_ids:
             return {}
 
-        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def f(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
             if from_key:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
+                    FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ? AND
                 """
                 clause, args = make_in_list_sql_clause(
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [from_key, to_key] + list(args))
+                txn.execute(
+                    sql + clause,
+                    [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+                )
             else:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
+                    FROM receipts_linearized WHERE
                    stream_id <= ? AND
                 """
@@ -378,31 +438,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [to_key] + list(args))
+                txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
 
-            return self.db_pool.cursor_to_dict(txn)
+            return [
+                (room_id, receipt_type, user_id, event_id, thread_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
         )
 
         results: JsonDict = {}
-        for row in txn_results:
+        for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
             room_event = results.setdefault(
-                row["room_id"],
-                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+                room_id,
+                {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
             )
 
             # The content is of the form:
             # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
-            event_entry = room_event["content"].setdefault(row["event_id"], {})
-            receipt_type = event_entry.setdefault(row["receipt_type"], {})
+            event_entry = room_event["content"].setdefault(event_id, {})
+            receipt_type_dict = event_entry.setdefault(receipt_type, {})
 
-            receipt_type[row["user_id"]] = db_to_json(row["data"])
-            if row["thread_id"]:
-                receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+            receipt_type_dict[user_id] = db_to_json(data)
+            if thread_id:
+                receipt_type_dict[user_id]["thread_id"] = thread_id
 
         results = {
             room_id: [results[room_id]] if room_id in results else []
@@ -414,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=2,
     )
     async def get_linearized_receipts_for_all_rooms(
-        self, to_key: int, from_key: Optional[int] = None
+        self,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, JsonMapping]:
         """Get receipts for all rooms between two stream_ids, up to a limit of the
         latest 100 read receipts.
@@ -428,46 +496,54 @@ class ReceiptsWorkerStore(SQLBaseStore):
             A dictionary of roomids to a list of receipts.
         """
 
-        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
             if from_key:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
-                txn.execute(sql, [from_key, to_key])
+                txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
             else:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
                    stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
 
-                txn.execute(sql, [to_key])
+                txn.execute(sql, [to_key.get_max_stream_pos()])
 
-            return self.db_pool.cursor_to_dict(txn)
+            return [
+                (room_id, receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "get_linearized_receipts_for_all_rooms", f
         )
 
         results: JsonDict = {}
-        for row in txn_results:
+        for room_id, receipt_type, user_id, event_id, data in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
             room_event = results.setdefault(
-                row["room_id"],
-                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+                room_id,
+                {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
             )
 
             # The content is of the form:
             # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
-            event_entry = room_event["content"].setdefault(row["event_id"], {})
-            receipt_type = event_entry.setdefault(row["receipt_type"], {})
+            event_entry = room_event["content"].setdefault(event_id, {})
+            receipt_type_dict = event_entry.setdefault(receipt_type, {})
 
-            receipt_type[row["user_id"]] = db_to_json(row["data"])
+            receipt_type_dict[user_id] = db_to_json(data)
 
         return results
 
@@ -537,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
+                AND instance_name = ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
-            txn.execute(sql, (last_id, current_id, limit))
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
 
             updates = cast(
                 List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -687,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
@@ -742,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         event_ids: List[str],
         thread_id: Optional[str],
         data: dict,
-    ) -> Optional[Tuple[int, int]]:
+    ) -> Optional[PersistedPosition]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
@@ -804,9 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             data,
         )
 
-        max_persisted_id = self._receipts_id_gen.get_current_token()
-
-        return stream_id, max_persisted_id
+        return PersistedPosition(self._instance_name, stream_id)
 
     async def _insert_graph_receipt(
         self,
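The linearized-receipts hunks share one pattern: SQL can only bound `stream_id` by the maximum position across all writers (`to_key.get_max_stream_pos()`), so the queries over-select and each row is then re-checked against its own writer's position via `MultiWriterStreamToken.is_stream_position_in_range`. A simplified, self-contained model of that per-row check, assuming instances absent from a token's map sit at its minimum position (`Token` and `in_range` are hypothetical stand-ins, not the real implementation):

from typing import Mapping, Optional


class Token:
    def __init__(self, stream: int, instance_map: Optional[Mapping[str, int]] = None):
        self.stream = stream
        self.instance_map = dict(instance_map or {})

    def for_instance(self, name: str) -> int:
        # Writers without an explicit entry are assumed to be at the minimum.
        return self.instance_map.get(name, self.stream)


def in_range(frm: Optional[Token], to: Token, instance: str, stream_id: int) -> bool:
    low = frm.for_instance(instance) if frm else 0
    return low < stream_id <= to.for_instance(instance)


# The SQL upper bound is the max across all writers, so rows past a slower
# writer's own position can come back and must be dropped by the check.
frm, to = Token(3), Token(5, {"worker1": 8})
assert in_range(frm, to, "worker1", 8)      # inside worker1's range (3, 8]
assert not in_range(frm, to, "worker2", 7)  # past worker2's position, 5
assert not in_range(frm, to, "worker1", 3)  # not after the from token

The same concern drives the remaining changes: the replication-stream query now filters on `instance_name` so each writer's updates are fetched against its own counter, and `insert_receipt` returns a `PersistedPosition(instance_name, stream_id)` instead of a `(stream_id, max_persisted_id)` pair, letting callers wait for that specific writer's position to replicate.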