diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index f74aa1e3f3..21e954ccc1 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
- def insert_linearized_receipt_txn(
+ def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
@@ -686,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rx_ts
+ def _graph_to_linear(
+ self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
+ ) -> str:
+ """
+ Generate a linearized event from a list of events (i.e. a list of forward
+ extremities in the room).
+
+ This should allow for calculation of the correct read receipt even if
+ servers have different event ordering.
+
+ Args:
+ txn: The transaction
+ room_id: The room ID the events are in.
+ event_ids: The list of event IDs to linearize.
+
+ Returns:
+ The linearized event ID.
+ """
+ # TODO: Make this better.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", event_ids
+ )
+
+ sql = """
+ SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+ SELECT max(stream_ordering) WHERE %s
+ )
+ """ % (
+ clause,
+ )
+
+ txn.execute(sql, [room_id] + list(args))
+ rows = txn.fetchall()
+ if rows:
+ return rows[0][0]
+ else:
+ raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
async def insert_receipt(
self,
room_id: str,
@@ -712,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
- # TODO: Make this better.
- def graph_to_linear(txn: LoggingTransaction) -> str:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "event_id", event_ids
- )
-
- sql = """
- SELECT event_id WHERE room_id = ? AND stream_ordering IN (
- SELECT max(stream_ordering) WHERE %s
- )
- """ % (
- clause,
- )
-
- txn.execute(sql, [room_id] + list(args))
- rows = txn.fetchall()
- if rows:
- return rows[0][0]
- else:
- raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
-
linearized_event_id = await self.db_pool.runInteraction(
- "insert_receipt_conv", graph_to_linear
+ "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
- self.insert_linearized_receipt_txn,
+ self._insert_linearized_receipt_txn,
room_id,
receipt_type,
user_id,
@@ -761,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
now - event_ts,
)
- await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
-
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
-
- async def insert_graph_receipt(
- self,
- room_id: str,
- receipt_type: str,
- user_id: str,
- event_ids: List[str],
- data: JsonDict,
- ) -> None:
- assert self._can_write_to_receipts
-
await self.db_pool.runInteraction(
"insert_graph_receipt",
- self.insert_graph_receipt_txn,
+ self._insert_graph_receipt_txn,
room_id,
receipt_type,
user_id,
@@ -787,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- def insert_graph_receipt_txn(
+ max_persisted_id = self._receipts_id_gen.get_current_token()
+
+ return stream_id, max_persisted_id
+
+ def _insert_graph_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
|