diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
deleted file mode 100644
index 3d45da38ab..0000000000
--- a/tests/app/test_frontend_proxy.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.app.generic_worker import GenericWorkerServer
-
-from tests.server import make_request
-from tests.unittest import HomeserverTestCase
-
-
-class FrontendProxyTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- federation_http_client=None, homeserver_to_use=GenericWorkerServer
- )
-
- return hs
-
- def default_config(self):
- c = super().default_config()
- c["worker_app"] = "synapse.app.frontend_proxy"
-
- c["worker_listeners"] = [
- {
- "type": "http",
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
- ]
-
- return c
-
- def test_listen_http_with_presence_enabled(self):
- """
- When presence is on, the stub servlet will not register.
- """
- # Presence is on
- self.hs.config.use_presence = True
-
- # Listen with the config
- self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
-
- # Grab the resource from the site that was told to listen
- self.assertEqual(len(self.reactor.tcpServers), 1)
- site = self.reactor.tcpServers[0][1]
-
- channel = make_request(self.reactor, site, "PUT", "presence/a/status")
-
- # 400 + unrecognised, because nothing is registered
- self.assertEqual(channel.code, 400)
- self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
- def test_listen_http_with_presence_disabled(self):
- """
- When presence is off, the stub servlet will register.
- """
- # Presence is off
- self.hs.config.use_presence = False
-
- # Listen with the config
- self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
-
- # Grab the resource from the site that was told to listen
- self.assertEqual(len(self.reactor.tcpServers), 1)
- site = self.reactor.tcpServers[0][1]
-
- channel = make_request(self.reactor, site, "PUT", "presence/a/status")
-
- # 401, because the stub servlet still checks authentication
- self.assertEqual(channel.code, 401)
- self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 0444b26798..b625995d12 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -13,7 +13,7 @@
# limitations under the License.
from unittest.mock import Mock
-from synapse.handlers.cas_handler import CasResponse
+from synapse.handlers.cas import CasResponse
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index c7b0975a19..8796af45ed 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -222,7 +222,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- for i in range(3):
+ for _ in range(3):
event = create_invite()
self.get_success(
self.handler.on_invite_request(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 34d2fc1dfb..a25c89bd5b 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -499,7 +499,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("fetch_error")
# Handle code exchange failure
- from synapse.handlers.oidc_handler import OidcError
+ from synapse.handlers.oidc import OidcError
self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
@@ -583,7 +583,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- from synapse.handlers.oidc_handler import OidcError
+ from synapse.handlers.oidc import OidcError
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
@@ -1126,7 +1126,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url: str,
ui_auth_session_id: str = "",
) -> str:
- from synapse.handlers.oidc_handler import OidcSessionData
+ from synapse.handlers.oidc import OidcSessionData
return self.handler._token_generator.generate_oidc_session_token(
state=state,
@@ -1152,7 +1152,7 @@ async def _make_callback_with_userinfo(
userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success.
"""
- from synapse.handlers.oidc_handler import OidcSessionData
+ from synapse.handlers.oidc import OidcSessionData
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 2d12e82897..61271cd084 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
+from synapse.federation.sender import FederationSender
from synapse.handlers.presence import (
EXTERNAL_PROCESS_EXPIRY,
FEDERATION_PING_INTERVAL,
@@ -471,6 +472,168 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
+class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.presence_handler = hs.get_presence_handler()
+ self.clock = hs.get_clock()
+ self.instance_name = hs.get_instance_name()
+
+ self.queue = self.presence_handler.get_federation_queue()
+
+ def test_send_and_get(self):
+ state1 = UserPresenceState.default("@user1:test")
+ state2 = UserPresenceState.default("@user2:test")
+ state3 = UserPresenceState.default("@user3:test")
+
+ prev_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+ now_token = self.queue.get_current_token(self.instance_name)
+
+ rows, upto_token, limited = self.get_success(
+ self.queue.get_replication_rows("master", prev_token, now_token, 10)
+ )
+
+ self.assertEqual(upto_token, now_token)
+ self.assertFalse(limited)
+
+ expected_rows = [
+ (1, ("dest1", "@user1:test")),
+ (1, ("dest2", "@user1:test")),
+ (1, ("dest1", "@user2:test")),
+ (1, ("dest2", "@user2:test")),
+ (2, ("dest3", "@user3:test")),
+ ]
+
+ self.assertCountEqual(rows, expected_rows)
+
+ def test_send_and_get_split(self):
+ state1 = UserPresenceState.default("@user1:test")
+ state2 = UserPresenceState.default("@user2:test")
+ state3 = UserPresenceState.default("@user3:test")
+
+ prev_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+
+ now_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+ rows, upto_token, limited = self.get_success(
+ self.queue.get_replication_rows("master", prev_token, now_token, 10)
+ )
+
+ self.assertEqual(upto_token, now_token)
+ self.assertFalse(limited)
+
+ expected_rows = [
+ (1, ("dest1", "@user1:test")),
+ (1, ("dest2", "@user1:test")),
+ (1, ("dest1", "@user2:test")),
+ (1, ("dest2", "@user2:test")),
+ ]
+
+ self.assertCountEqual(rows, expected_rows)
+
+ def test_clear_queue_all(self):
+ state1 = UserPresenceState.default("@user1:test")
+ state2 = UserPresenceState.default("@user2:test")
+ state3 = UserPresenceState.default("@user3:test")
+
+ prev_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+ self.reactor.advance(10 * 60 * 1000)
+
+ now_token = self.queue.get_current_token(self.instance_name)
+
+ rows, upto_token, limited = self.get_success(
+ self.queue.get_replication_rows("master", prev_token, now_token, 10)
+ )
+ self.assertEqual(upto_token, now_token)
+ self.assertFalse(limited)
+ self.assertCountEqual(rows, [])
+
+ prev_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+ now_token = self.queue.get_current_token(self.instance_name)
+
+ rows, upto_token, limited = self.get_success(
+ self.queue.get_replication_rows("master", prev_token, now_token, 10)
+ )
+ self.assertEqual(upto_token, now_token)
+ self.assertFalse(limited)
+
+ expected_rows = [
+ (3, ("dest1", "@user1:test")),
+ (3, ("dest2", "@user1:test")),
+ (3, ("dest1", "@user2:test")),
+ (3, ("dest2", "@user2:test")),
+ (4, ("dest3", "@user3:test")),
+ ]
+
+ self.assertCountEqual(rows, expected_rows)
+
+ def test_partially_clear_queue(self):
+ state1 = UserPresenceState.default("@user1:test")
+ state2 = UserPresenceState.default("@user2:test")
+ state3 = UserPresenceState.default("@user3:test")
+
+ prev_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+
+ self.reactor.advance(2 * 60 * 1000)
+
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+ self.reactor.advance(4 * 60 * 1000)
+
+ now_token = self.queue.get_current_token(self.instance_name)
+
+ rows, upto_token, limited = self.get_success(
+ self.queue.get_replication_rows("master", prev_token, now_token, 10)
+ )
+ self.assertEqual(upto_token, now_token)
+ self.assertFalse(limited)
+
+ expected_rows = [
+ (2, ("dest3", "@user3:test")),
+ ]
+ self.assertCountEqual(rows, [])
+
+ prev_token = self.queue.get_current_token(self.instance_name)
+
+ self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+ now_token = self.queue.get_current_token(self.instance_name)
+
+ rows, upto_token, limited = self.get_success(
+ self.queue.get_replication_rows("master", prev_token, now_token, 10)
+ )
+ self.assertEqual(upto_token, now_token)
+ self.assertFalse(limited)
+
+ expected_rows = [
+ (3, ("dest1", "@user1:test")),
+ (3, ("dest2", "@user1:test")),
+ (3, ("dest1", "@user2:test")),
+ (3, ("dest2", "@user2:test")),
+ (4, ("dest3", "@user3:test")),
+ ]
+
+ self.assertCountEqual(rows, expected_rows)
+
+
class PresenceJoinTestCase(unittest.HomeserverTestCase):
"""Tests remote servers get told about presence of users in the room when
they join and when new local users join.
@@ -482,10 +645,17 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "server", federation_http_client=None, federation_sender=Mock()
+ "server",
+ federation_http_client=None,
+ federation_sender=Mock(spec=FederationSender),
)
return hs
+ def default_config(self):
+ config = super().default_config()
+ config["send_federation"] = True
+ return config
+
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
@@ -529,9 +699,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Add a new remote server to the room
self._add_new_user(room_id, "@alice:server2")
- # We shouldn't have sent out any local presence *updates*
- self.federation_sender.send_presence.assert_not_called()
-
# When new server is joined we send it the local users presence states.
# We expect to only see user @test2:server, as @test:server is offline
# and has a zero last_active_ts
@@ -550,7 +717,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.federation_sender.reset_mock()
self._add_new_user(room_id, "@bob:server3")
- self.federation_sender.send_presence.assert_not_called()
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=["server3"], states={expected_state}
)
@@ -595,9 +761,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.reactor.pump([0]) # Wait for presence updates to be handled
- # We shouldn't have sent out any local presence *updates*
- self.federation_sender.send_presence.assert_not_called()
-
# We expect to only send test2 presence to server2 and server3
expected_state = self.get_success(
self.presence_handler.current_state_for_user("@test2:server")
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 9e97185507..ed9a884d76 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -26,6 +26,7 @@ from twisted.web.http import HTTPChannel
from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import (
+ MAX_RESPONSE_SIZE,
MatrixFederationHttpClient,
MatrixFederationRequest,
)
@@ -560,3 +561,61 @@ class FederationClientTests(HomeserverTestCase):
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
+
+ def test_too_big(self):
+ """
+ Test what happens if a huge response is returned from the remote endpoint.
+ """
+
+ test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+ self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it a huge HTTP response
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"\r\n"
+ )
+
+ self.pump()
+
+ # should still be waiting
+ self.assertNoResult(test_d)
+
+ sent = 0
+ chunk_size = 1024 * 512
+ while not test_d.called:
+ protocol.dataReceived(b"a" * chunk_size)
+ sent += chunk_size
+ self.assertLessEqual(sent, MAX_RESPONSE_SIZE)
+
+ self.assertEqual(sent, MAX_RESPONSE_SIZE)
+
+ f = self.failureResultOf(test_d)
+ self.assertIsInstance(f.value, RequestSendFailed)
+
+ self.assertTrue(transport.disconnecting)
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
new file mode 100644
index 0000000000..8c13b4f693
--- /dev/null
+++ b/tests/http/test_site.py
@@ -0,0 +1,83 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet.address import IPv6Address
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.app.homeserver import SynapseHomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SynapseRequestTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
+
+ def test_large_request(self):
+ """overlarge HTTP requests should be rejected"""
+ self.hs.start_listening()
+
+ # find the HTTP server which is configured to listen on port 0
+ (port, factory, _backlog, interface) = self.reactor.tcpServers[0]
+ self.assertEqual(interface, "::")
+ self.assertEqual(port, 0)
+
+ # as a control case, first send a regular request.
+
+ # complete the connection and wire it up to a fake transport
+ client_address = IPv6Address("TCP", "::1", "2345")
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+
+ while not transport.disconnecting:
+ self.reactor.advance(1)
+
+ # we should get a 404
+ self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
+
+ # now send an oversized request
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ )
+
+ # we deliberately send all the data in one big chunk, to ensure that
+ # twisted isn't buffering the data in the chunked transfer decoder.
+ # we start with the chunk size, in hex. (We won't actually send this much)
+ protocol.dataReceived(b"10000000\r\n")
+ sent = 0
+ while not transport.disconnected:
+ self.assertLess(sent, 0x10000000, "connection did not drop")
+ protocol.dataReceived(b"\0" * 1024)
+ sent += 1024
+
+ # default max upload size is 50M, so it should drop on the next buffer after
+ # that.
+ self.assertEqual(sent, 50 * 1024 * 1024 + 1024)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c9d04aef29..624bd1b927 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+ channel = self.site.buildProtocol(None)
+
+ # hook into the channel's request factory so that we can keep a record
+ # of the requests
+ requests: List[SynapseRequest] = []
+ real_request_factory = channel.requestFactory
+
+ def request_factory(*args, **kwargs):
+ request = real_request_factory(*args, **kwargs)
+ requests.append(request)
+ return request
+
+ channel.requestFactory = request_factory
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return channel.request
+ # there should have been exactly one request
+ self.assertEqual(len(requests), 1)
+
+ return requests[0]
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
+ max_request_body_size=4096,
+ reactor=self.reactor,
)
if worker_hs.config.redis.redis_enabled:
@@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+ channel = self._hs_to_site[hs].buildProtocol(None)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
self.received_rdata_rows.append((stream_name, token, r))
-class _PushHTTPChannel(HTTPChannel):
- """A HTTPChannel that wraps pull producers to push producers.
-
- This is a hack to get around the fact that HTTPChannel transparently wraps a
- pull producer (which is what Synapse uses to reply to requests) with
- `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
- uses the standard reactor rather than letting us use our test reactor, which
- makes it very hard to test.
- """
-
- def __init__(
- self, reactor: IReactorTime, request_factory: Type[Request], site: Site
- ):
- super().__init__()
- self.reactor = reactor
- self.requestFactory = request_factory
- self.site = site
-
- self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
-
- def registerProducer(self, producer, streaming):
- # Convert pull producers to push producer.
- if not streaming:
- self._pull_to_push_producer = _PullToPushProducer(
- self.reactor, producer, self
- )
- producer = self._pull_to_push_producer
-
- super().registerProducer(producer, True)
-
- def unregisterProducer(self):
- if self._pull_to_push_producer:
- # We need to manually stop the _PullToPushProducer.
- self._pull_to_push_producer.stop()
-
- def checkPersistence(self, request, version):
- """Check whether the connection can be re-used"""
- # We hijack this to always say no for ease of wiring stuff up in
- # `handle_http_replication_attempt`.
- request.responseHeaders.setRawHeaders(b"connection", [b"close"])
- return False
-
- def requestDone(self, request):
- # Store the request for inspection.
- self.request = request
- super().requestDone(request)
-
-
-class _PullToPushProducer:
- """A push producer that wraps a pull producer."""
-
- def __init__(
- self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
- ):
- self._clock = Clock(reactor)
- self._producer = producer
- self._consumer = consumer
-
- # While running we use a looping call with a zero delay to call
- # resumeProducing on given producer.
- self._looping_call = None # type: Optional[LoopingCall]
-
- # We start writing next reactor tick.
- self._start_loop()
-
- def _start_loop(self):
- """Start the looping call to"""
-
- if not self._looping_call:
- # Start a looping call which runs every tick.
- self._looping_call = self._clock.looping_call(self._run_once, 0)
-
- def stop(self):
- """Stops calling resumeProducing."""
- if self._looping_call:
- self._looping_call.stop()
- self._looping_call = None
-
- def pauseProducing(self):
- """Implements IPushProducer"""
- self.stop()
-
- def resumeProducing(self):
- """Implements IPushProducer"""
- self._start_loop()
-
- def stopProducing(self):
- """Implements IPushProducer"""
- self.stop()
- self._producer.stopProducing()
-
- def _run_once(self):
- """Calls resumeProducing on producer once."""
-
- try:
- self._producer.resumeProducing()
- except Exception:
- logger.exception("Failed to call resumeProducing")
- try:
- self._consumer.unregisterProducer()
- except Exception:
- pass
-
- self.stopProducing()
-
-
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 323237c1bb..f51fa0a79e 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -239,7 +239,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
- for stream_name, token, row in received_rows:
+ for stream_name, _, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
@@ -356,7 +356,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
- for j in range(STATES_PER_USER + 1):
+ for _ in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index ecbee30bb5..120730b764 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -430,7 +430,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
# Create devices
number_devices = 5
- for n in range(number_devices):
+ for _ in range(number_devices):
self.login("user", "pass")
# Get devices
@@ -547,7 +547,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
# Create devices
number_devices = 5
- for n in range(number_devices):
+ for _ in range(number_devices):
self.login("user", "pass")
# Get devices
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 8c66da3af4..29341bc6e9 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -48,22 +48,22 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
# Two rooms and two users. Every user sends and reports every room event
- for i in range(5):
+ for _ in range(5):
self._create_event_and_report(
room_id=self.room_id1,
user_tok=self.other_user_tok,
)
- for i in range(5):
+ for _ in range(5):
self._create_event_and_report(
room_id=self.room_id2,
user_tok=self.other_user_tok,
)
- for i in range(5):
+ for _ in range(5):
self._create_event_and_report(
room_id=self.room_id1,
user_tok=self.admin_user_tok,
)
- for i in range(5):
+ for _ in range(5):
self._create_event_and_report(
room_id=self.room_id2,
user_tok=self.admin_user_tok,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 6bcd997085..6b84188120 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -615,7 +615,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Create 3 test rooms
total_rooms = 3
room_ids = []
- for x in range(total_rooms):
+ for _ in range(total_rooms):
room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
@@ -679,7 +679,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Create 5 test rooms
total_rooms = 5
room_ids = []
- for x in range(total_rooms):
+ for _ in range(total_rooms):
room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
@@ -1577,7 +1577,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
- for i, found_event in enumerate(channel.json_body["events_before"]):
+ for found_event in channel.json_body["events_before"]:
for j, posted_event in enumerate(events):
if found_event["event_id"] == posted_event["event_id"]:
self.assertTrue(j < midway)
@@ -1585,7 +1585,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
else:
self.fail("Event %s from events_before not found" % j)
- for i, found_event in enumerate(channel.json_body["events_after"]):
+ for found_event in channel.json_body["events_after"]:
for j, posted_event in enumerate(events):
if found_event["event_id"] == posted_event["event_id"]:
self.assertTrue(j > midway)
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 363bdeeb2d..79cac4266b 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -467,7 +467,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
number_media: Number of media to be created for the user
"""
upload_resource = self.media_repo.children[b"upload"]
- for i in range(number_media):
+ for _ in range(number_media):
# file size is 67 Byte
image_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 2844c493fc..d599a4c984 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,7 +18,7 @@ import json
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import UserTypes
@@ -54,8 +54,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
- self.secrets = Mock()
-
self.hs = self.setup_test_homeserver()
self.hs.config.registration_shared_secret = "shared"
@@ -84,14 +82,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
Calling GET on the endpoint will return a randomised nonce, using the
homeserver's secrets provider.
"""
- secrets = Mock()
- secrets.token_hex = Mock(return_value="abcd")
-
- self.hs.get_secrets = Mock(return_value=secrets)
+ with patch("secrets.token_hex") as token_hex:
+ # Patch secrets.token_hex for the duration of this context
+ token_hex.return_value = "abcd"
- channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
- self.assertEqual(channel.json_body, {"nonce": "abcd"})
+ self.assertEqual(channel.json_body, {"nonce": "abcd"})
def test_expired_nonce(self):
"""
@@ -1937,7 +1934,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
# Create rooms and join
other_user_tok = self.login("user", "pass")
number_rooms = 5
- for n in range(number_rooms):
+ for _ in range(number_rooms):
self.helper.create_room_as(self.other_user, tok=other_user_tok)
# Get rooms
@@ -2517,7 +2514,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
user_token: Access token of the user
number_media: Number of media to be created for the user
"""
- for i in range(number_media):
+ for _ in range(number_media):
# file size is 67 Byte
image_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 3a050659ca..409f3949dc 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from twisted.internet import defer
+from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence
from synapse.types import UserID
@@ -32,7 +33,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- presence_handler = Mock()
+ presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None)
hs = self.setup_test_homeserver(
@@ -59,12 +60,12 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
+ @unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self):
"""
PUT to the status endpoint with use_presence disabled will NOT call
set_state on the presence handler.
"""
- self.hs.config.use_presence = False
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 92babf65e0..a3694f3d02 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -646,7 +646,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
def test_invites_by_users_ratelimit(self):
"""Tests that invites to a specific user are actually rate-limited."""
- for i in range(3):
+ for _ in range(3):
room_id = self.helper.create_room_as(self.user_id)
self.helper.invite(room_id, self.user_id, "@other-users:red")
@@ -668,7 +668,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
- for i in range(3):
+ for _ in range(3):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@@ -733,7 +733,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
- for i in range(4):
+ for _ in range(4):
channel = self.make_request("POST", path % room_id, {})
self.assertEquals(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 054d4e4140..1cad5f00eb 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -310,6 +310,57 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(channel.json_body.get("sid"))
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_reject_invalid_email(self):
+ """Check that bad emails are rejected"""
+
+ # Test for email with multiple @
+ channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
+ )
+ self.assertEquals(400, channel.code, channel.result)
+ # Check error to ensure that we're not erroring due to a bug in the test.
+ self.assertEquals(
+ channel.json_body,
+ {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
+ )
+
+ # Test for email with no @
+ channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": "email", "send_attempt": 1},
+ )
+ self.assertEquals(400, channel.code, channel.result)
+ self.assertEquals(
+ channel.json_body,
+ {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
+ )
+
+ # Test for super long email
+ email = "a@" + "a" * 1000
+ channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": email, "send_attempt": 1},
+ )
+ self.assertEquals(400, channel.code, channel.result)
+ self.assertEquals(
+ channel.json_body,
+ {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
+ )
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
@@ -492,8 +543,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- # Move 6 days forward. This should trigger a renewal email to be sent.
- self.reactor.advance(datetime.timedelta(days=6).total_seconds())
+ # Move 5 days forward. This should trigger a renewal email to be sent.
+ self.reactor.advance(datetime.timedelta(days=5).total_seconds())
self.assertEqual(len(self.email_attempts), 1)
# Retrieving the URL from the email is too much pain for now, so we
@@ -504,14 +555,32 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect on a successful renewal.
- expected_html = self.hs.config.account_validity.account_renewed_html_content
+ expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
+ # Move 1 day forward. Try to renew with the same token again.
+ url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+ channel = self.make_request(b"GET", url)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+ # Check that the HTML we're getting is the one we expect when reusing a
+ # token. The account expiration date should not have changed.
+ expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -531,15 +600,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect when using an
# invalid/unknown token.
- expected_html = self.hs.config.account_validity.invalid_token_html_content
+ expected_html = (
+ self.hs.config.account_validity.account_validity_invalid_token_template.render()
+ )
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -647,7 +715,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
config["account_validity"] = {"enabled": False}
self.hs = self.setup_test_homeserver(config=config)
- self.hs.config.account_validity.period = self.validity_period
+
+ # We need to set these directly, instead of in the homeserver config dict above.
+ # This is due to account validity-related config options not being read by
+ # Synapse when account_validity.enabled is False.
+ self.hs.get_datastore()._account_validity_period = self.validity_period
+ self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
self.store = self.hs.get_datastore()
diff --git a/tests/server.py b/tests/server.py
index b535a5d886..9df8cda24f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -603,12 +603,6 @@ class FakeTransport:
if self.disconnected:
return
- if not hasattr(self.other, "transport"):
- # the other has no transport yet; reschedule
- if self.autoflush:
- self._reactor.callLater(0.0, self.flush)
- return
-
if maxbytes is not None:
to_write = self.buffer[:maxbytes]
else:
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 6339a43f0c..200b9198f9 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import secrets
from tests import unittest
@@ -21,7 +22,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore()
- self.table_name = "table_" + hs.get_secrets().token_hex(6)
+ self.table_name = "table_" + secrets.token_hex(6)
self.get_success(
self.storage.db_pool.runInteraction(
"create",
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 397e68fe0a..088fbb247b 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -38,12 +38,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
last_event = None
# Make a real event chain
- for i in range(event_count):
+ for _ in range(event_count):
ev = self.create_and_send_event(room_id, user, False, last_event)
last_event = [ev]
# Sprinkle in some extremities
- for i in range(extrems):
+ for _ in range(extrems):
ev = self.create_and_send_event(room_id, user, False, last_event)
# Let it run for a while, then pull out the statistics from the
diff --git a/tests/test_server.py b/tests/test_server.py
index 55cde7f62f..407e172e41 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
parse_listener_def({"type": "http", "port": 0}),
self.resource,
"1.0",
+ max_request_body_size=1234,
+ reactor=self.reactor,
)
# render the request and return the channel
diff --git a/tests/unittest.py b/tests/unittest.py
index d890ad981f..74db7c08f1 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@ import hashlib
import hmac
import inspect
import logging
+import secrets
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from unittest.mock import Mock, patch
@@ -133,7 +134,7 @@ class TestCase(unittest.TestCase):
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
- for (key, value) in attrs.items():
+ for key in attrs.keys():
if not hasattr(obj, key):
raise AssertionError("Expected obj to have a '.%s'" % key)
try:
@@ -247,6 +248,8 @@ class HomeserverTestCase(TestCase):
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
+ max_request_body_size=1234,
+ reactor=self.reactor,
)
from tests.rest.client.v1.utils import RestHelper
@@ -624,7 +627,6 @@ class HomeserverTestCase(TestCase):
str: The new event's ID.
"""
event_creator = self.hs.get_event_creation_handler()
- secrets = self.hs.get_secrets()
requester = create_requester(user)
event, context = self.get_success(
diff --git a/tests/utils.py b/tests/utils.py
index af6b32fc66..6bd008dcfe 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -153,6 +153,10 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_invites": {
+ "per_room": {"per_second": 10000, "burst_count": 10000},
+ "per_user": {"per_second": 10000, "burst_count": 10000},
+ },
"rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False,
"public_baseurl": None,
@@ -303,7 +307,7 @@ def setup_test_homeserver(
# database for a few more seconds due to flakiness, preventing
# us from dropping it when the test is over. If we can't drop
# it, warn and move on.
- for x in range(5):
+ for _ in range(5):
try:
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
|