diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0231f9407b..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from immutabledict import immutabledict
+
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ MultiWriterStreamToken,
+ PersistedPosition,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
- max_value=max_receipts_stream_id,
+ max_value=max_receipts_stream_id.stream,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
- def get_max_receipt_stream_id(self) -> int:
+ def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
- return self._receipts_id_gen.get_current_token()
+
+ min_pos = self._receipts_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._receipts_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return MultiWriterStreamToken(
+ stream=min_pos, instance_map=immutabledict(positions)
+ )
+
+ def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+ return self._receipts_id_gen.get_current_token_for_writer(instance_name)
def get_last_unthreaded_receipt_for_user_txn(
self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Iterable[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
- room_ids, from_key
+ room_ids, from_key.stream
)
results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
- if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ if not self._receipts_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ ):
return []
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(tree=True)
async def _get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id > ? AND stream_id <= ?"
- )
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, from_key, to_key))
- else:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id <= ?"
+ txn.execute(
+ sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
)
+ else:
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
+ room_id = ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, to_key))
-
- rows = self.db_pool.cursor_to_dict(txn)
+ txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
- return rows
+ return [
+ (receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -339,10 +387,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []
content: JsonDict = {}
- for row in rows:
- content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
- row["user_id"]
- ] = db_to_json(row["data"])
+ for receipt_type, user_id, event_id, data in rows:
+ content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+ user_id
+ ] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@@ -352,25 +400,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=3,
)
async def _get_linearized_receipts_for_rooms(
- self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Collection[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [from_key, to_key] + list(args))
+ txn.execute(
+ sql + clause,
+ [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+ )
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -378,31 +438,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [to_key] + list(args))
+ txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return self.db_pool.cursor_to_dict(txn)
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
- if row["thread_id"]:
- receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+ receipt_type_dict[user_id] = db_to_json(data)
+ if thread_id:
+ receipt_type_dict[user_id]["thread_id"] = thread_id
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -414,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
- self, to_key: int, from_key: Optional[int] = None
+ self,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -428,46 +496,54 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [from_key, to_key])
+ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [to_key])
+ txn.execute(sql, [to_key.get_max_stream_pos()])
- return self.db_pool.cursor_to_dict(txn)
+ return [
+ (room_id, receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
+ receipt_type_dict[user_id] = db_to_json(data)
return results
@@ -537,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
updates = cast(
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -687,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues=keyvalues,
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
@@ -742,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[Tuple[int, int]]:
+ ) -> Optional[PersistedPosition]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -804,9 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
+ return PersistedPosition(self._instance_name, stream_id)
async def _insert_graph_receipt(
self,
|