diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index f4b5fb3328..93b9fad012 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.resource import ReplicationResource
-from synapse.types import Requester, UserID
+import contextlib
+import json
+from mock import Mock, NonCallableMock
from twisted.internet import defer
+
+import synapse.types
+from synapse.replication.resource import ReplicationResource
+from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver, requester_for_user
-from mock import Mock, NonCallableMock
-import json
-import contextlib
+from tests.utils import setup_test_homeserver
class ReplicationResourceCase(unittest.TestCase):
@@ -61,18 +63,18 @@ class ReplicationResourceCase(unittest.TestCase):
def test_events(self):
get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
- Requester(self.user, "", False), {}
+ synapse.types.create_requester(self.user), {}
)
code, body = yield get
self.assertEquals(code, 200)
self.assertEquals(body["events"]["field_names"], [
- "position", "internal", "json"
+ "position", "internal", "json", "state_group"
])
@defer.inlineCallbacks
def test_presence(self):
get = self.get(presence="-1")
- yield self.hs.get_handlers().presence_handler.set_state(
+ yield self.hs.get_presence_handler().set_state(
self.user, {"presence": "online"}
)
code, body = yield get
@@ -87,7 +89,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_typing(self):
room_id = yield self.create_room()
get = self.get(typing="-1")
- yield self.hs.get_handlers().typing_notification_handler.started_typing(
+ yield self.hs.get_typing_handler().started_typing(
self.user, self.user, room_id, timeout=2
)
code, body = yield get
@@ -101,7 +103,7 @@ class ReplicationResourceCase(unittest.TestCase):
room_id = yield self.create_room()
event_id = yield self.send_text_message(room_id, "Hello, World")
get = self.get(receipts="-1")
- yield self.hs.get_handlers().receipts_handler.received_client_receipt(
+ yield self.hs.get_receipts_handler().received_client_receipt(
room_id, "m.read", self.user_id, event_id
)
code, body = yield get
@@ -118,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase):
self.hs.clock.advance_time_msec(1)
code, body = yield get
self.assertEquals(code, 200)
- self.assertEquals(body, {})
+ self.assertEquals(body.get("rows", []), [])
test_timeout.__name__ = "test_timeout_%s" % (stream)
return test_timeout
@@ -132,12 +134,13 @@ class ReplicationResourceCase(unittest.TestCase):
test_timeout_backfill = _test_timeout("backfill")
test_timeout_push_rules = _test_timeout("push_rules")
test_timeout_pushers = _test_timeout("pushers")
+ test_timeout_state = _test_timeout("state")
@defer.inlineCallbacks
def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event(
- requester_for_user(self.user),
+ synapse.types.create_requester(self.user),
{
"type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"},
@@ -150,7 +153,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks
def create_room(self):
result = yield self.hs.get_handlers().room_creation_handler.create_room(
- Requester(self.user, "", False), {}
+ synapse.types.create_requester(self.user), {}
)
defer.returnValue(result["room_id"])
@@ -182,4 +185,20 @@ class ReplicationResourceCase(unittest.TestCase):
)
response_body = json.loads(response_json)
+ if response_code == 200:
+ self.check_response(response_body)
+
defer.returnValue((response_code, response_body))
+
+ def check_response(self, response_body):
+ for name, stream in response_body.items():
+ self.assertIn("field_names", stream)
+ field_names = stream["field_names"]
+ self.assertIn("rows", stream)
+ for row in stream["rows"]:
+ self.assertEquals(
+ len(row), len(field_names),
+ "%s: len(row = %r) == len(field_names = %r)" % (
+ name, row, field_names
+ )
+ )
|