diff --git a/changelog.d/13937.feature b/changelog.d/13937.feature
new file mode 100644
index 0000000000..d0cb902dff
--- /dev/null
+++ b/changelog.d/13937.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index f4cdc2e399..7e0ffef7d3 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
) -> NotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
- result = self.get_last_receipt_for_user_txn(
+ result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
user_id,
room_id,
- receipt_types=(
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ),
+ receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
if result:
@@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ),
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
sql = f"""
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 52fe0db924..246f78ac1f 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
- async def get_last_receipt_event_id_for_user(
- self, user_id: str, room_id: str, receipt_types: Collection[str]
- ) -> Optional[str]:
- """
- Fetch the event ID for the latest receipt in a room with one of the given receipt types.
-
- Args:
- user_id: The user to fetch receipts for.
- room_id: The room ID to fetch the receipt for.
- receipt_type: The receipt types to fetch.
-
- Returns:
- The latest receipt, if one exists.
- """
- result = await self.db_pool.runInteraction(
- "get_last_receipt_event_id_for_user",
- self.get_last_receipt_for_user_txn,
- user_id,
- room_id,
- receipt_types,
- )
- if not result:
- return None
-
- event_id, _ = result
- return event_id
-
- def get_last_receipt_for_user_txn(
+ def get_last_unthreaded_receipt_for_user_txn(
self,
txn: LoggingTransaction,
user_id: str,
@@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]:
"""
- Fetch the event ID and stream_ordering for the latest receipt in a room
- with one of the given receipt types.
+ Fetch the event ID and stream_ordering for the latest unthreaded receipt
+ in a room with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
- receipt_type: The receipt types to fetch.
+ receipt_types: The receipt types to fetch.
Returns:
The event ID and stream ordering of the latest receipt, if one exists.
@@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
WHERE {clause}
AND user_id = ?
AND room_id = ?
+ AND thread_id IS NULL
ORDER BY stream_ordering DESC
LIMIT 1
"""
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 9459ee1705..81253d0361 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, Optional
from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
@@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase):
)
)
+ def get_last_unthreaded_receipt(
+ self, receipt_types: Collection[str], room_id: Optional[str] = None
+ ) -> Optional[str]:
+ """
+ Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
+
+ Args:
+ receipt_types: The receipt types to fetch.
+
+ Returns:
+ The latest receipt, if one exists.
+ """
+ result = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get_last_receipt_event_id_for_user",
+ self.store.get_last_unthreaded_receipt_for_user_txn,
+ OUR_USER_ID,
+ room_id or self.room_id1,
+ receipt_types,
+ )
+ )
+ if not result:
+ return None
+
+ event_id, _ = result
+ return event_id
+
def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
self.store.get_receipts_for_user(
@@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {})
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
+
self.assertEqual(res, None)
def test_get_receipts_for_user(self) -> None:
@@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase):
)
# Test we get the latest event when we want both private and public receipts
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, event1_2_id)
# Test we get the older event when we want only public receipt
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_1_id)
# Test we get the latest event when we want only the private receipt
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
self.assertEqual(res, event1_2_id)
# Test receipt updating
@@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase):
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_2_id)
# Send some events into the second room
@@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase):
{},
)
)
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id2,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
)
self.assertEqual(res, event2_1_id)
|