diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 89a3f95c0a..bb867150f4 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -323,6 +323,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"renew_at": 172800000, # Time in ms for 2 days
"renew_by_email_enabled": True,
"renew_email_subject": "Renew your account",
+ "account_renewed_html_path": "account_renewed.html",
+ "invalid_token_html_path": "invalid_token.html",
}
# Email config.
@@ -373,6 +375,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ # Check that we're getting HTML back.
+ content_type = None
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type = header[1]
+ self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+
+ # Check that the HTML we're getting is the one we expect on a successful renewal.
+ expected_html = self.hs.config.account_validity.account_renewed_html_content
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
# Move 3 days forward. If the renewal failed, every authed request with
# our access token should be denied from now, otherwise they should
# succeed.
@@ -381,6 +396,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_renewal_invalid_token(self):
+ # Hit the renewal endpoint with an invalid token and check that it behaves as
+ # expected, i.e. that it responds with 404 Not Found and the correct HTML.
+ url = "/_matrix/client/unstable/account_validity/renew?token=123"
+ request, channel = self.make_request(b"GET", url)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"404", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = None
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type = header[1]
+ self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+
+ # Check that the HTML we're getting is the one we expect when using an
+ # invalid/unknown token.
+ expected_html = self.hs.config.account_validity.invalid_token_html_content
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
def test_manual_email_send(self):
self.email_attempts = []
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 8488b6edc8..d961b81d48 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -17,6 +17,8 @@
from mock import Mock
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
event.unsigned["redacted_because"],
)
+
+ def test_circular_redaction(self):
+ redaction_event_id1 = "$redaction1_id:test"
+ redaction_event_id2 = "$redaction2_id:test"
+
+ class EventIdManglingBuilder:
+ def __init__(self, base_builder, event_id):
+ self._base_builder = base_builder
+ self._event_id = event_id
+
+ @defer.inlineCallbacks
+ def build(self, prev_event_ids):
+ built_event = yield self._base_builder.build(prev_event_ids)
+ built_event.event_id = self._event_id
+ built_event._event_dict["event_id"] = self._event_id
+ return built_event
+
+ @property
+ def room_id(self):
+ return self._base_builder.room_id
+
+ event_1, context_1 = self.get_success(
+ self.event_creation_handler.create_new_client_event(
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id2,
+ },
+ ),
+ redaction_event_id1,
+ )
+ )
+ )
+
+ self.get_success(self.store.persist_event(event_1, context_1))
+
+ event_2, context_2 = self.get_success(
+ self.event_creation_handler.create_new_client_event(
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id1,
+ },
+ ),
+ redaction_event_id2,
+ )
+ )
+ )
+ self.get_success(self.store.persist_event(event_2, context_2))
+
+ # fetch one of the redactions
+ fetched = self.get_success(self.store.get_event(redaction_event_id1))
+
+ # it should have been redacted
+ self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
+ self.assertEqual(
+ fetched.unsigned["redacted_because"].event_id, redaction_event_id2
+ )
|