summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_receipts.py55
1 files changed, 38 insertions, 17 deletions
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index b1a8f8bba7..191c957fb5 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from parameterized import parameterized
+
 from synapse.api.constants import ReceiptTypes
 from synapse.types import UserID, create_requester
 
@@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test"
 
 
 class ReceiptTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor, clock, homeserver) -> None:
         super().prepare(reactor, clock, homeserver)
 
         self.store = homeserver.get_datastores().main
@@ -83,10 +85,15 @@ class ReceiptTestCase(HomeserverTestCase):
             )
         )
 
-    def test_return_empty_with_no_data(self):
+    def test_return_empty_with_no_data(self) -> None:
         res = self.get_success(
             self.store.get_receipts_for_user(
-                OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+                OUR_USER_ID,
+                [
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ],
             )
         )
         self.assertEqual(res, {})
@@ -94,7 +101,11 @@ class ReceiptTestCase(HomeserverTestCase):
         res = self.get_success(
             self.store.get_receipts_for_user_with_orderings(
                 OUR_USER_ID,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+                [
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ],
             )
         )
         self.assertEqual(res, {})
@@ -103,12 +114,19 @@ class ReceiptTestCase(HomeserverTestCase):
             self.store.get_last_receipt_event_id_for_user(
                 OUR_USER_ID,
                 self.room_id1,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+                [
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ],
             )
         )
         self.assertEqual(res, None)
 
-    def test_get_receipts_for_user(self):
+    @parameterized.expand(
+        [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
+    )
+    def test_get_receipts_for_user(self, receipt_type: str) -> None:
         # Send some events into the first room
         event1_1_id = self.create_and_send_event(
             self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -126,14 +144,14 @@ class ReceiptTestCase(HomeserverTestCase):
         # Send private read receipt for the second event
         self.get_success(
             self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+                self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
             )
         )
 
         # Test we get the latest event when we want both private and public receipts
         res = self.get_success(
             self.store.get_receipts_for_user(
-                OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+                OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
             )
         )
         self.assertEqual(res, {self.room_id1: event1_2_id})
@@ -146,7 +164,7 @@ class ReceiptTestCase(HomeserverTestCase):
 
         # Test we get the latest event when we want only the public receipt
         res = self.get_success(
-            self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE])
+            self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type])
         )
         self.assertEqual(res, {self.room_id1: event1_2_id})
 
@@ -169,17 +187,20 @@ class ReceiptTestCase(HomeserverTestCase):
         # Test new room is reflected in what the method returns
         self.get_success(
             self.store.insert_receipt(
-                self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+                self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
             )
         )
         res = self.get_success(
             self.store.get_receipts_for_user(
-                OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+                OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
             )
         )
         self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
 
-    def test_get_last_receipt_event_id_for_user(self):
+    @parameterized.expand(
+        [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
+    )
+    def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None:
         # Send some events into the first room
         event1_1_id = self.create_and_send_event(
             self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -197,7 +218,7 @@ class ReceiptTestCase(HomeserverTestCase):
         # Send private read receipt for the second event
         self.get_success(
             self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+                self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
             )
         )
 
@@ -206,7 +227,7 @@ class ReceiptTestCase(HomeserverTestCase):
             self.store.get_last_receipt_event_id_for_user(
                 OUR_USER_ID,
                 self.room_id1,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+                [ReceiptTypes.READ, receipt_type],
             )
         )
         self.assertEqual(res, event1_2_id)
@@ -222,7 +243,7 @@ class ReceiptTestCase(HomeserverTestCase):
         # Test we get the latest event when we want only the private receipt
         res = self.get_success(
             self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
+                OUR_USER_ID, self.room_id1, [receipt_type]
             )
         )
         self.assertEqual(res, event1_2_id)
@@ -248,14 +269,14 @@ class ReceiptTestCase(HomeserverTestCase):
         # Test new room is reflected in what the method returns
         self.get_success(
             self.store.insert_receipt(
-                self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+                self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
             )
         )
         res = self.get_success(
             self.store.get_last_receipt_event_id_for_user(
                 OUR_USER_ID,
                 self.room_id2,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+                [ReceiptTypes.READ, receipt_type],
             )
         )
         self.assertEqual(res, event2_1_id)