diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index e90592855a..a6e91956af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -14,6 +14,7 @@
from typing import Optional
from unittest.mock import Mock
+from parameterized import parameterized_class
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
@@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
+@parameterized_class(
+ [
+ {"enable_room_poke_code_path": False},
+ {"enable_room_poke_code_path": True},
+ ]
+)
class FederationSenderDevicesTestCases(HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self):
c = super().default_config()
c["send_federation"] = True
+ c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
return c
def prepare(self, reactor, clock, hs):
- # stub out get_users_who_share_room_with_user so that it claims that
- # `@user2:host2` is in the room
- def get_users_who_share_room_with_user(user_id):
+ # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+ # server thinks the user shares a room with `@user2:host2`
+ def get_rooms_for_user(user_id):
+ return defer.succeed({"!room:host1"})
+
+ hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+
+ def get_users_in_room(room_id):
return defer.succeed({"@user2:host2"})
- hs.get_datastores().main.get_users_who_share_room_with_user = (
- get_users_who_share_room_with_user
- )
+ hs.get_datastores().main.get_users_in_room = get_users_in_room
# whenever send_transaction is called, record the edu data
self.edus = []
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add two device updates with sequential `stream_id`s
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
"device_id5",
]
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add some more device updates to ensure it still resumes properly
device_ids = ["device_id6", "device_id7"]
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
self.get_success(
self.store.add_device_change_to_streams(
- "@user_id:test", device_ids, ["somehost"]
+ "@user_id:test", device_ids, ["somehost"], ["!some:room"]
)
)
|