diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index b1a8f8bba7..81253d0361 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, Optional
+
from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
@@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor, clock, homeserver) -> None:
super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main
@@ -83,10 +85,41 @@ class ReceiptTestCase(HomeserverTestCase):
)
)
- def test_return_empty_with_no_data(self):
+ def get_last_unthreaded_receipt(
+ self, receipt_types: Collection[str], room_id: Optional[str] = None
+ ) -> Optional[str]:
+ """
+ Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
+
+ Args:
+ receipt_types: The receipt types to fetch.
+
+ Returns:
+ The latest receipt, if one exists.
+ """
+ result = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get_last_receipt_event_id_for_user",
+ self.store.get_last_unthreaded_receipt_for_user_txn,
+ OUR_USER_ID,
+ room_id or self.room_id1,
+ receipt_types,
+ )
+ )
+ if not result:
+ return None
+
+ event_id, _ = result
+ return event_id
+
+ def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
self.store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ OUR_USER_ID,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, {})
@@ -94,21 +127,21 @@ class ReceiptTestCase(HomeserverTestCase):
res = self.get_success(
self.store.get_receipts_for_user_with_orderings(
OUR_USER_ID,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, {})
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
+
self.assertEqual(res, None)
- def test_get_receipts_for_user(self):
+ def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -120,13 +153,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event1_2_id],
+ None,
+ {},
)
)
@@ -153,7 +191,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
@@ -169,7 +207,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+ self.room_id2,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event2_1_id],
+ None,
+ {},
)
)
res = self.get_success(
@@ -179,7 +222,7 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
- def test_get_last_receipt_event_id_for_user(self):
+ def test_get_last_receipt_event_id_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -191,53 +234,42 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event1_2_id],
+ None,
+ {},
)
)
# Test we get the latest event when we want both private and public receipts
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, event1_2_id)
# Test we get the older event when we want only public receipt
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_1_id)
# Test we get the latest event when we want only the private receipt
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
self.assertEqual(res, event1_2_id)
# Test receipt updating
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
- )
- )
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_2_id)
# Send some events into the second room
@@ -248,14 +280,15 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
- )
- )
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
self.room_id2,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event2_1_id],
+ None,
+ {},
)
)
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
+ )
self.assertEqual(res, event2_1_id)
|