summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_sender.py23
-rw-r--r--tests/storage/test_devices.py14
2 files changed, 27 insertions, 10 deletions
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index e90592855a..a6e91956af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -14,6 +14,7 @@
 from typing import Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized_class
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 
@@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         )
 
 
+@parameterized_class(
+    [
+        {"enable_room_poke_code_path": False},
+        {"enable_room_poke_code_path": True},
+    ]
+)
 class FederationSenderDevicesTestCases(HomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     def default_config(self):
         c = super().default_config()
         c["send_federation"] = True
+        c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_users_who_share_room_with_user so that it claims that
-        # `@user2:host2` is in the room
-        def get_users_who_share_room_with_user(user_id):
+        # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+        # server thinks the user shares a room with `@user2:host2`
+        def get_rooms_for_user(user_id):
+            return defer.succeed({"!room:host1"})
+
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+
+        def get_users_in_room(room_id):
             return defer.succeed({"@user2:host2"})
 
-        hs.get_datastores().main.get_users_who_share_room_with_user = (
-            get_users_who_share_room_with_user
-        )
+        hs.get_datastores().main.get_users_in_room = get_users_in_room
 
         # whenever send_transaction is called, record the edu data
         self.edus = []
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         # Add two device updates with sequential `stream_id`s
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
             "device_id5",
         ]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Add some more device updates to ensure it still resumes properly
         device_ids = ["device_id6", "device_id7"]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         self.get_success(
             self.store.add_device_change_to_streams(
-                "@user_id:test", device_ids, ["somehost"]
+                "@user_id:test", device_ids, ["somehost"], ["!some:room"]
             )
         )