diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 758ee071a5..4cbe9784ed 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
- user_id1 = "@user1:server"
- user_id2 = "@user2:server"
+ user_id1 = "@user1:test"
+ user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b68e9fe082..b1b037006d 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def test_invites(self):
self.persist(type="m.room.create", key="", creator=USER_ID)
- self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
- "get_invited_rooms_for_user",
+ "get_invited_rooms_for_local_user",
[USER_ID_2],
[
RoomsForUser(
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 1d14e77255..e96ad4ca4e 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -73,6 +73,6 @@ class TestReplicationClientHandler(object):
def finished_connecting(self):
pass
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 7a7e898843..f3b4a31e21 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -337,7 +337,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"local_invites",
"room_account_data",
"room_tags",
- "state_groups",
+ # "state_groups", # Current impl leaves orphaned state groups around.
"state_groups_state",
):
count = self.get_success(
@@ -351,8 +351,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
- test_purge_room.skip = "Disabled because it's currently broken"
-
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API.
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0f51895b81..c3facc00eb 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
# Make sure the invite is here.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
@@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
- store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+ store.get_rooms_for_local_user_where_membership_is(
+ invitee_id, [Membership.LEAVE]
+ )
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 661c1f88b9..9c13a13786 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,8 +15,6 @@
# limitations under the License.
import json
-from mock import Mock
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
@@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
- )
- return hs
-
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
self.render(request)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 7840f63fe3..00df0ea68e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
rooms_for_user = self.get_success(
- self.store.get_rooms_for_user_where_membership_is(
+ self.store.get_rooms_for_local_user_where_membership_is(
self.u_alice, [Membership.JOIN]
)
)
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ JsonResource,
+ wrap_html_request_handler,
+)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+ class TestResource(DirectServeResource):
+ callback = None
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self.callback(request)
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def test_good_response(self):
+ def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ body = channel.result["body"]
+ self.assertEqual(body, b"response")
+
+ def test_redirect_exception(self):
+ """
+ If the callback raises a RedirectException, it is turned into a 30x
+ with the right location.
+ """
+
+ def callback(request, **kwargs):
+ raise RedirectException(b"/look/an/eagle", 301)
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"301")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+ def test_redirect_exception_with_cookie(self):
+ """
+ If the callback raises a RedirectException which sets a cookie, that is
+ returned too
+ """
+
+ def callback(request, **kwargs):
+ e = RedirectException(b"/no/over/there", 304)
+ e.cookies.append(b"session=yespls")
+ raise e
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"304")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/no/over/there"])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
new file mode 100644
index 0000000000..0ab0a91483
--- /dev/null
+++ b/tests/util/test_itertools.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util.iterutils import chunk_seq
+
+from tests.unittest import TestCase
+
+
+class ChunkSeqTests(TestCase):
+ def test_short_seq(self):
+ parts = chunk_seq("123", 8)
+
+ self.assertEqual(
+ list(parts), ["123"],
+ )
+
+ def test_long_seq(self):
+ parts = chunk_seq("abcdefghijklmnop", 8)
+
+ self.assertEqual(
+ list(parts), ["abcdefgh", "ijklmnop"],
+ )
+
+ def test_uneven_parts(self):
+ parts = chunk_seq("abcdefghijklmnop", 5)
+
+ self.assertEqual(
+ list(parts), ["abcde", "fghij", "klmno", "p"],
+ )
+
+ def test_empty_input(self):
+ parts = chunk_seq([], 5)
+
+ self.assertEqual(
+ list(parts), [],
+ )
|