From ec10bdd32bb52af73789f5f60b39135578a739b1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Oct 2020 15:09:31 +0100 Subject: Speed up unit tests when using PostgreSQL (#8450) --- tests/server.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tests/server.py') diff --git a/tests/server.py b/tests/server.py index b404ad4e2a..f7f5276b21 100644 --- a/tests/server.py +++ b/tests/server.py @@ -372,6 +372,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool.threadpool = ThreadPool(clock._reactor) pool.running = True + # We've just changed the Databases to run DB transactions on the same + # thread, so we need to disable the dedicated thread behaviour. + server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server -- cgit 1.5.1 From 9789b1fba541a5ae01b946770416729e5b7e5b7e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 9 Oct 2020 17:22:25 +0100 Subject: Fix threadsafety in ThreadedMemoryReactorClock (#8497) This could, very occasionally, cause: ``` tests.test_visibility.FilterEventsForServerTestCase.test_large_room =============================================================================== [ERROR] Traceback (most recent call last): File "/src/tests/rest/media/v1/test_media_storage.py", line 86, in test_ensure_media_is_in_local_cache self.wait_on_thread(x) File "/src/tests/unittest.py", line 296, in wait_on_thread self.reactor.advance(0.01) File "/src/.tox/py35/lib/python3.5/site-packages/twisted/internet/task.py", line 826, in advance self._sortCalls() File "/src/.tox/py35/lib/python3.5/site-packages/twisted/internet/task.py", line 787, in _sortCalls self.calls.sort(key=lambda a: a.getTime()) builtins.ValueError: list modified during sort tests.rest.media.v1.test_media_storage.MediaStorageTests.test_ensure_media_is_in_local_cache ``` --- changelog.d/8497.misc | 1 + tests/server.py | 36 ++++++++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8497.misc (limited to 'tests/server.py') diff --git a/changelog.d/8497.misc b/changelog.d/8497.misc new file mode 100644 index 0000000000..8bc05e8df6 --- /dev/null +++ b/changelog.d/8497.misc @@ -0,0 +1 @@ +Fix a threadsafety bug in unit tests. diff --git a/tests/server.py b/tests/server.py index f7f5276b21..422c8b42ca 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,8 +1,11 @@ import json import logging +from collections import deque from io import SEEK_END, BytesIO +from typing import Callable import attr +from typing_extensions import Deque from zope.interface import implementer from twisted.internet import address, threads, udp @@ -251,6 +254,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] lookups = self.lookups = {} + self._thread_callbacks = deque() # type: Deque[Callable[[], None]]() @implementer(IResolverSimple) class FakeResolver: @@ -272,10 +276,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ Make the callback fire in the next reactor iteration. """ - d = Deferred() - d.addCallback(lambda x: callback(*args, **kwargs)) - self.callLater(0, d.callback, True) - return d + cb = lambda: callback(*args, **kwargs) + # it's not safe to call callLater() here, so we append the callback to a + # separate queue. 
+        self._thread_callbacks.append(cb)
 
     def getThreadPool(self):
         return self.threadpool
@@ -303,6 +307,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
+    def advance(self, amount):
+        # first advance our reactor's time, and run any "callLater" callbacks that
+        # become ready
+        super().advance(amount)
+
+        # now run any "callFromThread" callbacks
+        while True:
+            try:
+                callback = self._thread_callbacks.popleft()
+            except IndexError:
+                break
+            callback()
+
+            # check for more "callLater" callbacks added by the thread callback
+            # This isn't required in a regular reactor, but it ends up meaning that
+            # our database queries can complete in a single call to `advance` [1] which
+            # simplifies tests.
+            #
+            # [1]: we replace the threadpool backing the db connection pool with a
+            # mock ThreadPool which doesn't really use threads; but we still use
+            # reactor.callFromThread to feed results back from the db functions to the
+            # main thread.
+            super().advance(0)
+
 
 class ThreadPool:
     """
-- 
cgit 1.5.1


From d35a451399d5bb15ba0b452c26719474371298d7 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 9 Oct 2020 14:19:29 -0400
Subject: Clean-up some broken/unused code in the test framework (#8514)

---
 changelog.d/8514.misc |   1 +
 tests/server.py       |   2 -
 tests/utils.py        | 122 ++++++++++++++++++++++----------------
 3 files changed, 55 insertions(+), 70 deletions(-)
 create mode 100644 changelog.d/8514.misc

(limited to 'tests/server.py')

diff --git a/changelog.d/8514.misc b/changelog.d/8514.misc
new file mode 100644
index 0000000000..0e7ac4f220
--- /dev/null
+++ b/changelog.d/8514.misc
@@ -0,0 +1 @@
+Remove unused code from the test framework.
diff --git a/tests/server.py b/tests/server.py
index 422c8b42ca..4d33b84097 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -367,8 +367,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     """
     server = _sth(cleanup_func, *args, **kwargs)
 
-    database = server.config.database.get_single_database()
-
     # Make the thread pool synchronous.
clock = server.get_clock() diff --git a/tests/utils.py b/tests/utils.py index af563ffe0f..0c09f5457f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -192,7 +192,6 @@ class TestHomeServer(HomeServer): def setup_test_homeserver( cleanup_func, name="test", - datastore=None, config=None, reactor=None, homeserverToUse=TestHomeServer, @@ -249,7 +248,7 @@ def setup_test_homeserver( # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() - if datastore is None and isinstance(db_engine, PostgresEngine): + if isinstance(db_engine, PostgresEngine): db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER, @@ -265,79 +264,66 @@ def setup_test_homeserver( cur.close() db_conn.close() - if datastore is None: - hs = homeserverToUse( - name, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs - ) + hs = homeserverToUse( + name, + config=config, + version_string="Synapse/tests", + tls_server_context_factory=Mock(), + tls_client_options_factory=Mock(), + reactor=reactor, + **kargs + ) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_background_tasks() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_background_tasks() - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 - # Close all the db pools - database._db_pool.close() + # Close all the db pools + database._db_pool.close() - dropped = False + dropped = False - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for x in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() - else: - hs = homeserverToUse( - name, - datastore=datastore, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs - ) + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. 
+ for x in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) # bcrypt is far too slow to be doing in unit tests # Need to let the HS build an auth handler and then mess with it -- cgit 1.5.1 From aff1eb7c671b0a3813407321d2702ec46c71fa56 Mon Sep 17 00:00:00 2001 From: Dan Callahan Date: Tue, 27 Oct 2020 23:26:36 +0000 Subject: Tell Black to format code for Python 3.5 (#8664) This allows trailing commas in multi-line arg lists. Minor, but we might as well keep our formatting current with regard to our minimum supported Python version. Signed-off-by: Dan Callahan --- changelog.d/8664.misc | 1 + pyproject.toml | 2 +- synapse/http/client.py | 2 +- synapse/storage/database.py | 4 ++-- synapse/util/retryutils.py | 2 +- tests/replication/_base.py | 2 +- tests/replication/tcp/streams/test_events.py | 2 +- tests/server.py | 4 ++-- tests/storage/test_client_ips.py | 2 +- tests/test_utils/event_injection.py | 2 +- 10 files changed, 12 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8664.misc (limited to 'tests/server.py') diff --git a/changelog.d/8664.misc b/changelog.d/8664.misc new file mode 100644 index 0000000000..278cf53adc --- /dev/null +++ b/changelog.d/8664.misc @@ -0,0 +1 @@ +Tell Black to format code for Python 3.5. diff --git a/pyproject.toml b/pyproject.toml index db4a2e41e4..cd880d4e39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py34'] +target-version = ['py35'] exclude = ''' ( diff --git a/synapse/http/client.py b/synapse/http/client.py index 8324632cb6..f409368802 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -359,7 +359,7 @@ class SimpleHttpClient: agent=self.agent, data=body_producer, headers=headers, - **self._extra_treq_args + **self._extra_treq_args, ) # type: defer.Deferred # we use our own timeout mechanism rather than treq's as a workaround diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0217e63108..a0572b2952 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ def make_pool( cp_openfun=lambda conn: engine.on_new_connection( LoggingDatabaseConnection(conn, engine, "on_new_connection") ), - **db_config.config.get("args", {}) + **db_config.config.get("args", {}), ) @@ -632,7 +632,7 @@ class DatabasePool: func, *args, db_autocommit=db_autocommit, - **kwargs + **kwargs, ) for after_callback, after_args, after_kwargs in after_callbacks: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index a5cc9d0551..4ab379e429 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k failure_ts, retry_interval, backoff_on_failure=backoff_on_failure, - **kwargs + **kwargs, ) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 093e2faac7..f1e53f33cd 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -269,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, - 
**kwargs + **kwargs, ) # If the instance is in the `instance_map` config then workers may try diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index c9998e88e6..bad0df08cf 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase): sender=sender, type="test_event", content={"body": body}, - **kwargs + **kwargs, ) ) diff --git a/tests/server.py b/tests/server.py index 4d33b84097..ea9c22bc51 100644 --- a/tests/server.py +++ b/tests/server.py @@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runWithConnection, func, *args, - **kwargs + **kwargs, ) def runInteraction(interaction, *args, **kwargs): @@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runInteraction, interaction, *args, - **kwargs + **kwargs, ) pool.runWithConnection = runWithConnection diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 755c70db31..e96ca1c8ca 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, - **make_request_args + **make_request_args, ) request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e93aa84405..c3c4a93e1f 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -50,7 +50,7 @@ async def inject_member_event( sender=sender, state_key=target, content=content, - **kwargs + **kwargs, ) -- cgit 1.5.1 From 00b24aa545091395f9a92d531836f6bf7b4460e0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Oct 2020 07:27:37 -0400 Subject: Support generating structured logs in addition to standard logs. (#8607) This modifies the configuration of structured logging to be usable from the standard Python logging configuration. This also separates the formatting of logs from the transport allowing JSON logs to files or standard logs to sockets. 
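
As an illustrative sketch (not part of this commit), the new classes are meant
to be wired up through the standard library's `logging.config.dictConfig`; the
class names and handler arguments below come from this patch, while the host,
port and buffer values are placeholders:

```python
import logging
import logging.config

logging.config.dictConfig(
    {
        "version": 1,
        "formatters": {
            # Formats each record as a line of terse JSON.
            "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}
        },
        "handlers": {
            # Ships the JSON lines to a remote log aggregator over TCP.
            "remote": {
                "class": "synapse.logging.RemoteHandler",
                "formatter": "tersejson",
                "host": "127.0.0.1",  # placeholder aggregator address
                "port": 9999,  # placeholder port
                "maximum_buffer": 1000,
            }
        },
        "loggers": {"synapse": {"level": "INFO", "handlers": ["remote"]}},
    }
)

# Records logged via the standard library are now formatted as JSON and
# buffered/sent by the RemoteHandler.
logging.getLogger("synapse").info("hello over TCP")
```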
---
 UPGRADE.rst                          |  16 ++
 changelog.d/8607.misc                |   1 +
 docs/sample_log_config.yaml          |   4 +
 docs/structured_logging.md           | 164 ++++++++++++-----
 scripts-dev/lint.sh                  |   2 +-
 synapse/config/logger.py             |  96 +++++-----
 synapse/logging/__init__.py          |  20 +++
 synapse/logging/_remote.py           |  97 ++++++-----
 synapse/logging/_structured.py       | 329 ++++++-----------------------------
 synapse/logging/_terse_json.py       | 192 ++++++--------------
 synapse/logging/filter.py            |  33 ++++
 synmark/__init__.py                  |  39 -----
 synmark/__main__.py                  |   6 +-
 synmark/suites/logging.py            |  60 ++++---
 tests/logging/__init__.py            |  34 ++++
 tests/logging/test_remote_handler.py | 153 ++++++++++++++++
 tests/logging/test_structured.py     | 214 -----------------------
 tests/logging/test_terse_json.py     | 253 ++++++++-------------------
 tests/server.py                      |   4 +-
 19 files changed, 706 insertions(+), 1011 deletions(-)
 create mode 100644 changelog.d/8607.misc
 create mode 100644 synapse/logging/filter.py
 create mode 100644 tests/logging/test_remote_handler.py
 delete mode 100644 tests/logging/test_structured.py

(limited to 'tests/server.py')

diff --git a/UPGRADE.rst b/UPGRADE.rst
index 5a68312217..960c2aeb2b 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,22 @@ for example:
 
     wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
 
+Upgrading to v1.23.0
+====================
+
+Structured logging configuration breaking changes
+-------------------------------------------------
+
+This release deprecates use of the ``structured: true`` logging configuration for
+structured logging. If your logging configuration contains ``structured: true``
+then it should be modified based on the `structured logging documentation
+<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_.
+
+The ``structured`` and ``drains`` logging options are now deprecated and should
+be replaced by standard logging configuration of ``handlers`` and ``formatters``.
+
+A future release of Synapse will make using ``structured: true`` an error.
+
 Upgrading to v1.22.0
 ====================
 
diff --git a/changelog.d/8607.misc b/changelog.d/8607.misc
new file mode 100644
index 0000000000..9e56551a34
--- /dev/null
+++ b/changelog.d/8607.misc
@@ -0,0 +1 @@
+Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.
diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml
index e26657f9fe..ff3c747180 100644
--- a/docs/sample_log_config.yaml
+++ b/docs/sample_log_config.yaml
@@ -3,7 +3,11 @@
 # This is a YAML file containing a standard Python logging configuration
 # dictionary. See [1] for details on the valid settings.
 #
+# Synapse also supports structured logging for machine readable logs which can
+# be ingested by ELK stacks. See [2] for details.
+#
 # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
+# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
 
 version: 1
 
diff --git a/docs/structured_logging.md b/docs/structured_logging.md
index decec9b8fa..b1281667e0 100644
--- a/docs/structured_logging.md
+++ b/docs/structured_logging.md
@@ -1,83 +1,161 @@
 # Structured Logging
 
-A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack".
+A structured logging system can be useful when your logs are destined for a
+machine to parse and process. By maintaining its machine-readable characteristics,
+it enables more efficient searching and aggregations when consumed by software
+such as the "ELK stack".
 
-Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to).
+Synapse's structured logging system is configured via the file that Synapse's
+`log_config` config option points to. The file should include a formatter which
+uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a
+handler which uses the above formatter.
+
+There is also a `synapse.logging.JsonFormatter` option which does not include
+a timestamp in the resulting JSON. This is useful if the log ingester adds its
+own timestamp.
 
 A structured logging configuration looks similar to the following:
 
 ```yaml
-structured: true
+version: 1
+
+formatters:
+    structured:
+        class: synapse.logging.TerseJsonFormatter
+
+handlers:
+    file:
+        class: logging.handlers.TimedRotatingFileHandler
+        formatter: structured
+        filename: /path/to/my/logs/homeserver.log
+        when: midnight
+        backupCount: 3  # Does not include the current log file.
+        encoding: utf8
 
 loggers:
     synapse:
         level: INFO
+        handlers: [file]
     synapse.storage.SQL:
         level: WARNING
-
-drains:
-    console:
-        type: console
-        location: stdout
-    file:
-        type: file_json
-        location: homeserver.log
 ```
 
-The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON).
-
-## Drain Types
+The above logging config will set Synapse as 'INFO' logging level by default,
+with the SQL layer at 'WARNING', and will log to a file, stored as JSON.
 
-Drain types can be specified by the `type` key.
+It is also possible to configure Synapse to log to a remote endpoint by using the
+`synapse.logging.RemoteHandler` class included with Synapse. It takes the
+following arguments:
 
-### `console`
+- `host`: Hostname or IP address of the log aggregator.
+- `port`: Numerical port to contact on the host.
+- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow.
 
-Outputs human-readable logs to the console.
+A remote structured logging configuration looks similar to the following:
 
-Arguments:
+```yaml
+version: 1
 
-- `location`: Either `stdout` or `stderr`.
+formatters:
+    structured:
+        class: synapse.logging.TerseJsonFormatter
 
-### `console_json`
+handlers:
+    remote:
+        class: synapse.logging.RemoteHandler
+        formatter: structured
+        host: 10.1.2.3
+        port: 9999
 
-Outputs machine-readable JSON logs to the console.
+loggers:
+    synapse:
+        level: INFO
+        handlers: [remote]
+    synapse.storage.SQL:
+        level: WARNING
+```
 
-Arguments:
+The above logging config will set Synapse as 'INFO' logging level by default,
+with the SQL layer at 'WARNING', and will log JSON formatted messages to a
+remote endpoint at 10.1.2.3:9999.
 
-- `location`: Either `stdout` or `stderr`.
+## Upgrading from legacy structured logging configuration
 
-### `console_json_terse`
+Versions of Synapse prior to v1.23.0 included a custom structured logging
+configuration which is deprecated. It used a `structured: true` flag and
+configured `drains` instead of `handlers` and `formatters`.
 
-Outputs machine-readable JSON logs to the console, separated by newlines. This
-format is not designed to be read and re-formatted into human-readable text, but
-is optimal for a logging aggregation system.
+Synapse currently automatically converts the old configuration to the new
+configuration, but this will be removed in a future version of Synapse. The
+following reference can be used to update your configuration. Based on the drain
+`type`, we can pick a new handler:
 
-Arguments:
+1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
+   with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
+   or `ext://sys.stderr` should be used.
+2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with
+   a location of the file path should be used.
+3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler`
+   with the host and port should be used.
 
-- `location`: Either `stdout` or `stderr`.
+Then based on the drain `type` we can pick a new formatter:
 
-### `file`
+1. For a type of `console` or `file` no formatter is necessary.
+2. For a type of `console_json` or `file_json`: a formatter of
+   `synapse.logging.JsonFormatter` should be used.
+3. For a type of `console_json_terse` or `network_json_terse`: a formatter of
+   `synapse.logging.TerseJsonFormatter` should be used.
 
-Outputs human-readable logs to a file.
+Each new handler and formatter should be added to the logging configuration
+and then assigned to either a logger or the root logger.
 
-Arguments:
+An example legacy configuration:
 
-- `location`: An absolute path to the file to log to.
+```yaml
+structured: true
 
-### `file_json`
+loggers:
+    synapse:
+        level: INFO
+    synapse.storage.SQL:
+        level: WARNING
 
-Outputs machine-readable logs to a file.
+drains:
+    console:
+        type: console
+        location: stdout
+    file:
+        type: file_json
+        location: homeserver.log
+```
 
-Arguments:
+Would be converted into a new configuration:
 
-- `location`: An absolute path to the file to log to.
+```yaml
+version: 1
 
-### `network_json_terse`
+formatters:
+    json:
+        class: synapse.logging.JsonFormatter
 
-Delivers machine-readable JSON logs to a log aggregator over TCP. This is
-compatible with LogStash's TCP input with the codec set to `json_lines`.
+handlers:
+    console:
+        class: logging.StreamHandler
+        stream: ext://sys.stdout
+    file:
+        class: logging.FileHandler
+        formatter: json
+        filename: homeserver.log
 
-Arguments:
+loggers:
+    synapse:
+        level: INFO
+        handlers: [console, file]
+    synapse.storage.SQL:
+        level: WARNING
+```
 
-- `host`: Hostname or IP address of the log aggregator.
-- `port`: Numerical port to contact on the host.
\ No newline at end of file
+The new logging configuration is a bit more verbose, but significantly more
+flexible. It allows for configurations that were not previously possible, such as
+sending plain logs over the network, or using different handlers for different
+modules.
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index f141805519..f328ab57d5 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -80,7 +80,7 @@ else
   # then lint everything!
  if [[ -z ${files+x} ]]; then
     # Lint all source code files and directories
-    files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
+    files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
   fi
 fi
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 6b7be28aee..d4e887a3e0 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -23,7 +23,6 @@ from string import Template
 import yaml
 
 from twisted.logger import (
-    ILogObserver,
     LogBeginner,
     STDLibLogObserver,
     eventAsText,
@@ -32,11 +31,9 @@ from twisted.logger import (
 
 import synapse
 from synapse.app import _base as appbase
-from synapse.logging._structured import (
-    reload_structured_logging,
-    setup_structured_logging,
-)
+from synapse.logging._structured import setup_structured_logging
 from synapse.logging.context import LoggingContextFilter
+from synapse.logging.filter import MetadataFilter
 from synapse.util.versionstring import get_version_string
 
 from ._base import Config, ConfigError
@@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template(
 # This is a YAML file containing a standard Python logging configuration
 # dictionary. See [1] for details on the valid settings.
 #
+# Synapse also supports structured logging for machine readable logs which can
+# be ingested by ELK stacks. See [2] for details.
+#
 # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
+# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
 
 version: 1
 
@@ -176,11 +177,11 @@ class LoggingConfig(Config):
             log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
 
 
-def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
+def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
     """
-    Set up Python stdlib logging.
+    Set up Python standard library logging.
     """
-    if log_config is None:
+    if log_config_path is None:
         log_format = (
             "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
             " - %(message)s"
@@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
         handler.setFormatter(formatter)
         logger.addHandler(handler)
     else:
-        logging.config.dictConfig(log_config)
+        # Load the logging configuration.
+        _load_logging_config(log_config_path)
 
     # We add a log record factory that runs all messages through the
     # LoggingContextFilter so that we get the context *at the time we log*
     # rather than when we write to a handler. This can be done in config using
     # filter options, but care must be taken when using e.g. MemoryHandler to buffer
     # writes.
 
-    log_filter = LoggingContextFilter(request="")
+    log_context_filter = LoggingContextFilter(request="")
+    log_metadata_filter = MetadataFilter({"server_name": config.server_name})
     old_factory = logging.getLogRecordFactory()
 
     def factory(*args, **kwargs):
         record = old_factory(*args, **kwargs)
-        log_filter.filter(record)
+        log_context_filter.filter(record)
+        log_metadata_filter.filter(record)
         return record
 
     logging.setLogRecordFactory(factory)
@@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
     if not config.no_redirect_stdio:
         print("Redirected stdout/stderr to logs")
 
-    return observer
-
 
-def _reload_stdlib_logging(*args, log_config=None):
-    logger = logging.getLogger("")
+def _load_logging_config(log_config_path: str) -> None:
+    """
+    Configure logging from a log config path.
+ """ + with open(log_config_path, "rb") as f: + log_config = yaml.safe_load(f.read()) if not log_config: - logger.warning("Reloaded a blank config?") + logging.warning("Loaded a blank logging config?") + + # If the old structured logging configuration is being used, convert it to + # the new style configuration. + if "structured" in log_config and log_config.get("structured"): + log_config = setup_structured_logging(log_config) logging.config.dictConfig(log_config) +def _reload_logging_config(log_config_path): + """ + Reload the log configuration from the file and apply it. + """ + # If no log config path was given, it cannot be reloaded. + if log_config_path is None: + return + + _load_logging_config(log_config_path) + logging.info("Reloaded log config from %s due to SIGHUP", log_config_path) + + def setup_logging( hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner -) -> ILogObserver: +) -> None: """ Set up the logging subsystem. @@ -282,41 +305,18 @@ def setup_logging( logBeginner: The Twisted logBeginner to use. - Returns: - The "root" Twisted Logger observer, suitable for sending logs to from a - Logger instance. """ - log_config = config.worker_log_config if use_worker_options else config.log_config - - def read_config(*args, callback=None): - if log_config is None: - return None - - with open(log_config, "rb") as f: - log_config_body = yaml.safe_load(f.read()) - - if callback: - callback(log_config=log_config_body) - logging.info("Reloaded log config from %s due to SIGHUP", log_config) - - return log_config_body + log_config_path = ( + config.worker_log_config if use_worker_options else config.log_config + ) - log_config_body = read_config() + # Perform one-time logging configuration. + _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner) + # Add a SIGHUP handler to reload the logging configuration, if one is available. + appbase.register_sighup(_reload_logging_config, log_config_path) - if log_config_body and log_config_body.get("structured") is True: - logger = setup_structured_logging( - hs, config, log_config_body, logBeginner=logBeginner - ) - appbase.register_sighup(read_config, callback=reload_structured_logging) - else: - logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner) - appbase.register_sighup(read_config, callback=_reload_stdlib_logging) - - # make sure that the first thing we log is a thing we can grep backwards - # for + # Log immediately so we can grep backwards. logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) logging.info("Instance name: %s", hs.get_instance_name()) - - return logger diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py index e69de29bb2..b28b7b2ef7 100644 --- a/synapse/logging/__init__.py +++ b/synapse/logging/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These are imported to allow for nicer logging configuration files.
+from synapse.logging._remote import RemoteHandler
+from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+
+__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"]
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 0caf325916..ba45424f02 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import sys
 import traceback
 from collections import deque
@@ -21,6 +22,7 @@ from math import floor
 from typing import Callable, Optional
 
 import attr
+from typing_extensions import Deque
 from zope.interface import implementer
 
 from twisted.application.internet import ClientService
@@ -32,7 +34,8 @@ from twisted.internet.endpoints import (
 )
 from twisted.internet.interfaces import IPushProducer, ITransport
 from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import ILogObserver, Logger, LogLevel
+
+logger = logging.getLogger(__name__)
 
 
 @attr.s
@@ -45,11 +48,11 @@ class LogProducer:
     Args:
         buffer: Log buffer to read logs from.
         transport: Transport to write to.
-        format_event: A callable to format the log entry to a string.
+        format: A callable to format the log record to a string.
     """
 
     transport = attr.ib(type=ITransport)
-    format_event = attr.ib(type=Callable[[dict], str])
+    _format = attr.ib(type=Callable[[logging.LogRecord], str])
     _buffer = attr.ib(type=deque)
     _paused = attr.ib(default=False, type=bool, init=False)
 
@@ -61,16 +64,19 @@ class LogProducer:
         self._buffer = deque()
 
     def resumeProducing(self):
+        # If we're already producing, nothing to do.
         self._paused = False
 
+        # Loop until paused.
         while self._paused is False and (self._buffer and self.transport.connected):
             try:
-                # Request the next event and format it.
-                event = self._buffer.popleft()
-                msg = self.format_event(event)
+                # Request the next record and format it.
+                record = self._buffer.popleft()
+                msg = self._format(record)
 
                 # Send it as a new line over the transport.
                 self.transport.write(msg.encode("utf8"))
+                self.transport.write(b"\n")
             except Exception:
                 # Something has gone wrong writing to the transport -- log it
                 # and break out of the while.
@@ -78,60 +84,63 @@ class LogProducer:
                 break
 
 
-@attr.s
-@implementer(ILogObserver)
-class TCPLogObserver:
+class RemoteHandler(logging.Handler):
     """
-    An IObserver that writes JSON logs to a TCP target.
+    A logging handler that writes logs to a TCP target.
 
     Args:
-        hs (HomeServer): The homeserver that is being logged for.
         host: The host of the logging target.
         port: The logging target's port.
-        format_event: A callable to format the log entry to a string.
         maximum_buffer: The maximum buffer size.
""" - hs = attr.ib() - host = attr.ib(type=str) - port = attr.ib(type=int) - format_event = attr.ib(type=Callable[[dict], str]) - maximum_buffer = attr.ib(type=int) - _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) - _logger = attr.ib(default=attr.Factory(Logger)) - _producer = attr.ib(default=None, type=Optional[LogProducer]) - - def start(self) -> None: + def __init__( + self, + host: str, + port: int, + maximum_buffer: int = 1000, + level=logging.NOTSET, + _reactor=None, + ): + super().__init__(level=level) + self.host = host + self.port = port + self.maximum_buffer = maximum_buffer + + self._buffer = deque() # type: Deque[logging.LogRecord] + self._connection_waiter = None # type: Optional[Deferred] + self._producer = None # type: Optional[LogProducer] # Connect without DNS lookups if it's a direct IP. + if _reactor is None: + from twisted.internet import reactor + + _reactor = reactor + try: ip = ip_address(self.host) if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) + endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port) elif isinstance(ip, IPv6Address): - endpoint = TCP6ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) + endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) else: raise ValueError("Unknown IP address provided: %s" % (self.host,)) except ValueError: - endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) + endpoint = HostnameEndpoint(_reactor, self.host, self.port) factory = Factory.forProtocol(Protocol) - self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) + self._service = ClientService(endpoint, factory, clock=_reactor) self._service.startService() self._connect() - def stop(self): + def close(self): self._service.stopService() def _connect(self) -> None: """ Triggers an attempt to connect then write to the remote if not already writing. """ + # Do not attempt to open multiple connections. if self._connection_waiter: return @@ -158,9 +167,7 @@ class TCPLogObserver: # Make a new producer and start it. self._producer = LogProducer( - buffer=self._buffer, - transport=r.transport, - format_event=self.format_event, + buffer=self._buffer, transport=r.transport, format=self.format, ) r.transport.registerProducer(self._producer, True) self._producer.resumeProducing() @@ -168,19 +175,19 @@ class TCPLogObserver: def _handle_pressure(self) -> None: """ - Handle backpressure by shedding events. + Handle backpressure by shedding records. The buffer will, in this order, until the buffer is below the maximum: - - Shed DEBUG events - - Shed INFO events - - Shed the middle 50% of the events. + - Shed DEBUG records. + - Shed INFO records. + - Shed the middle 50% of the records. 
""" if len(self._buffer) <= self.maximum_buffer: return # Strip out DEBUGs self._buffer = deque( - filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer) + filter(lambda record: record.levelno > logging.DEBUG, self._buffer) ) if len(self._buffer) <= self.maximum_buffer: @@ -188,7 +195,7 @@ class TCPLogObserver: # Strip out INFOs self._buffer = deque( - filter(lambda event: event["log_level"] != LogLevel.info, self._buffer) + filter(lambda record: record.levelno > logging.INFO, self._buffer) ) if len(self._buffer) <= self.maximum_buffer: @@ -209,17 +216,17 @@ class TCPLogObserver: self._buffer.extend(reversed(end_buffer)) - def __call__(self, event: dict) -> None: - self._buffer.append(event) + def emit(self, record: logging.LogRecord) -> None: + self._buffer.append(record) # Handle backpressure, if it exists. try: self._handle_pressure() except Exception: - # If handling backpressure fails,clear the buffer and log the + # If handling backpressure fails, clear the buffer and log the # exception. self._buffer.clear() - self._logger.failure("Failed clearing backpressure") + logger.warning("Failed clearing backpressure") # Try and write immediately. self._connect() diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 0fc2ea609e..14d9c104c2 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -12,138 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import os.path -import sys -import typing -import warnings -from typing import List +from typing import Any, Dict, Generator, Optional, Tuple -import attr -from constantly import NamedConstant, Names, ValueConstant, Values -from zope.interface import implementer - -from twisted.logger import ( - FileLogObserver, - FilteringLogObserver, - ILogObserver, - LogBeginner, - Logger, - LogLevel, - LogLevelFilterPredicate, - LogPublisher, - eventAsText, - jsonFileLogObserver, -) +from constantly import NamedConstant, Names from synapse.config._base import ConfigError -from synapse.logging._terse_json import ( - TerseJSONToConsoleLogObserver, - TerseJSONToTCPLogObserver, -) -from synapse.logging.context import current_context - - -def stdlib_log_level_to_twisted(level: str) -> LogLevel: - """ - Convert a stdlib log level to Twisted's log level. - """ - lvl = level.lower().replace("warning", "warn") - return LogLevel.levelWithName(lvl) - - -@attr.s -@implementer(ILogObserver) -class LogContextObserver: - """ - An ILogObserver which adds Synapse-specific log context information. - - Attributes: - observer (ILogObserver): The target parent observer. - """ - - observer = attr.ib() - - def __call__(self, event: dict) -> None: - """ - Consume a log event and emit it to the parent observer after filtering - and adding log context information. - - Args: - event (dict) - """ - # Filter out some useless events that Twisted outputs - if "log_text" in event: - if event["log_text"].startswith("DNSDatagramProtocol starting on "): - return - - if event["log_text"].startswith("(UDP Port "): - return - - if event["log_text"].startswith("Timing out client") or event[ - "log_format" - ].startswith("Timing out client"): - return - - context = current_context() - - # Copy the context information to the log event. 
- context.copy_to_twisted_log_entry(event) - - self.observer(event) - - -class PythonStdlibToTwistedLogger(logging.Handler): - """ - Transform a Python stdlib log message into a Twisted one. - """ - - def __init__(self, observer, *args, **kwargs): - """ - Args: - observer (ILogObserver): A Twisted logging observer. - *args, **kwargs: Args/kwargs to be passed to logging.Handler. - """ - self.observer = observer - super().__init__(*args, **kwargs) - - def emit(self, record: logging.LogRecord) -> None: - """ - Emit a record to Twisted's observer. - - Args: - record (logging.LogRecord) - """ - - self.observer( - { - "log_time": record.created, - "log_text": record.getMessage(), - "log_format": "{log_text}", - "log_namespace": record.name, - "log_level": stdlib_log_level_to_twisted(record.levelname), - } - ) - - -def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver: - """ - A log observer that formats events like the traditional log formatter and - sends them to `outFile`. - - Args: - outFile (file object): The file object to write to. - """ - - def formatEvent(_event: dict) -> str: - event = dict(_event) - event["log_level"] = event["log_level"].name.upper() - event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + ( - event.get("log_format", "{log_text}") or "{log_text}" - ) - return eventAsText(event, includeSystem=False) + "\n" - - return FileLogObserver(outFile, formatEvent) class DrainType(Names): @@ -155,30 +29,12 @@ class DrainType(Names): NETWORK_JSON_TERSE = NamedConstant() -class OutputPipeType(Values): - stdout = ValueConstant(sys.__stdout__) - stderr = ValueConstant(sys.__stderr__) - - -@attr.s -class DrainConfiguration: - name = attr.ib() - type = attr.ib() - location = attr.ib() - options = attr.ib(default=None) - - -@attr.s -class NetworkJSONTerseOptions: - maximum_buffer = attr.ib(type=int) - - -DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} +DEFAULT_LOGGERS = {"synapse": {"level": "info"}} def parse_drain_configs( drains: dict, -) -> typing.Generator[DrainConfiguration, None, None]: +) -> Generator[Tuple[str, Dict[str, Any]], None, None]: """ Parse the drain configurations. @@ -186,11 +42,12 @@ def parse_drain_configs( drains (dict): A list of drain configurations. Yields: - DrainConfiguration instances. + dict instances representing a logging handler. Raises: ConfigError: If any of the drain configuration items are invalid. """ + for name, config in drains.items(): if "type" not in config: raise ConfigError("Logging drains require a 'type' key.") @@ -202,6 +59,18 @@ def parse_drain_configs( "%s is not a known logging drain type." % (config["type"],) ) + # Either use the default formatter or the tersejson one. + if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): + formatter = "json" # type: Optional[str] + elif logging_type in ( + DrainType.CONSOLE_JSON_TERSE, + DrainType.NETWORK_JSON_TERSE, + ): + formatter = "tersejson" + else: + # A formatter of None implies using the default formatter. + formatter = None + if logging_type in [ DrainType.CONSOLE, DrainType.CONSOLE_JSON, @@ -217,9 +86,11 @@ def parse_drain_configs( % (logging_type,) ) - pipe = OutputPipeType.lookupByName(location).value - - yield DrainConfiguration(name=name, type=logging_type, location=pipe) + yield name, { + "class": "logging.StreamHandler", + "formatter": formatter, + "stream": "ext://sys." 
+ location, + } elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: if "location" not in config: @@ -233,18 +104,25 @@ def parse_drain_configs( "File paths need to be absolute, '%s' is a relative path" % (location,) ) - yield DrainConfiguration(name=name, type=logging_type, location=location) + + yield name, { + "class": "logging.FileHandler", + "formatter": formatter, + "filename": location, + } elif logging_type in [DrainType.NETWORK_JSON_TERSE]: host = config.get("host") port = config.get("port") maximum_buffer = config.get("maximum_buffer", 1000) - yield DrainConfiguration( - name=name, - type=logging_type, - location=(host, port), - options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer), - ) + + yield name, { + "class": "synapse.logging.RemoteHandler", + "formatter": formatter, + "host": host, + "port": port, + "maximum_buffer": maximum_buffer, + } else: raise ConfigError( @@ -253,126 +131,29 @@ def parse_drain_configs( ) -class StoppableLogPublisher(LogPublisher): +def setup_structured_logging(log_config: dict,) -> dict: """ - A log publisher that can tell its observers to shut down any external - communications. - """ - - def stop(self): - for obs in self._observers: - if hasattr(obs, "stop"): - obs.stop() - - -def setup_structured_logging( - hs, - config, - log_config: dict, - logBeginner: LogBeginner, - redirect_stdlib_logging: bool = True, -) -> LogPublisher: - """ - Set up Twisted's structured logging system. - - Args: - hs: The homeserver to use. - config (HomeserverConfig): The configuration of the Synapse homeserver. - log_config (dict): The log configuration to use. + Convert a legacy structured logging configuration (from Synapse < v1.23.0) + to one compatible with the new standard library handlers. """ - if config.no_redirect_stdio: - raise ConfigError( - "no_redirect_stdio cannot be defined using structured logging." 
- ) - - logger = Logger() - if "drains" not in log_config: raise ConfigError("The logging configuration requires a list of drains.") - observers = [] # type: List[ILogObserver] - - for observer in parse_drain_configs(log_config["drains"]): - # Pipe drains - if observer.type == DrainType.CONSOLE: - logger.debug( - "Starting up the {name} console logger drain", name=observer.name - ) - observers.append(SynapseFileLogObserver(observer.location)) - elif observer.type == DrainType.CONSOLE_JSON: - logger.debug( - "Starting up the {name} JSON console logger drain", name=observer.name - ) - observers.append(jsonFileLogObserver(observer.location)) - elif observer.type == DrainType.CONSOLE_JSON_TERSE: - logger.debug( - "Starting up the {name} terse JSON console logger drain", - name=observer.name, - ) - observers.append( - TerseJSONToConsoleLogObserver(observer.location, metadata={}) - ) - - # File drains - elif observer.type == DrainType.FILE: - logger.debug("Starting up the {name} file logger drain", name=observer.name) - log_file = open(observer.location, "at", buffering=1, encoding="utf8") - observers.append(SynapseFileLogObserver(log_file)) - elif observer.type == DrainType.FILE_JSON: - logger.debug( - "Starting up the {name} JSON file logger drain", name=observer.name - ) - log_file = open(observer.location, "at", buffering=1, encoding="utf8") - observers.append(jsonFileLogObserver(log_file)) - - elif observer.type == DrainType.NETWORK_JSON_TERSE: - metadata = {"server_name": hs.config.server_name} - log_observer = TerseJSONToTCPLogObserver( - hs=hs, - host=observer.location[0], - port=observer.location[1], - metadata=metadata, - maximum_buffer=observer.options.maximum_buffer, - ) - log_observer.start() - observers.append(log_observer) - else: - # We should never get here, but, just in case, throw an error. - raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - - publisher = StoppableLogPublisher(*observers) - log_filter = LogLevelFilterPredicate() - - for namespace, namespace_config in log_config.get( - "loggers", DEFAULT_LOGGERS - ).items(): - # Set the log level for twisted.logger.Logger namespaces - log_filter.setLogLevelForNamespace( - namespace, - stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")), - ) - - # Also set the log levels for the stdlib logger namespaces, to prevent - # them getting to PythonStdlibToTwistedLogger and having to be formatted - if "level" in namespace_config: - logging.getLogger(namespace).setLevel(namespace_config.get("level")) - - f = FilteringLogObserver(publisher, [log_filter]) - lco = LogContextObserver(f) - - if redirect_stdlib_logging: - stuff_into_twisted = PythonStdlibToTwistedLogger(lco) - stdliblogger = logging.getLogger() - stdliblogger.addHandler(stuff_into_twisted) - - # Always redirect standard I/O, otherwise other logging outputs might miss - # it. - logBeginner.beginLoggingTo([lco], redirectStandardIO=True) + new_config = { + "version": 1, + "formatters": { + "json": {"class": "synapse.logging.JsonFormatter"}, + "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, + }, + "handlers": {}, + "loggers": log_config.get("loggers", DEFAULT_LOGGERS), + "root": {"handlers": []}, + } - return publisher + for handler_name, handler in parse_drain_configs(log_config["drains"]): + new_config["handlers"][handler_name] = handler + # Add each handler to the root logger. 
+ new_config["root"]["handlers"].append(handler_name) -def reload_structured_logging(*args, log_config=None) -> None: - warnings.warn( - "Currently the structured logging system can not be reloaded, doing nothing" - ) + return new_config diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 9b46956ca9..2fbf5549a1 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -16,141 +16,65 @@ """ Log formatters that output terse JSON. """ - import json -from typing import IO - -from twisted.logger import FileLogObserver - -from synapse.logging._remote import TCPLogObserver +import logging _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) - -def flatten_event(event: dict, metadata: dict, include_time: bool = False): - """ - Flatten a Twisted logging event to an dictionary capable of being sent - as a log event to a logging aggregation system. - - The format is vastly simplified and is not designed to be a "human readable - string" in the sense that traditional logs are. Instead, the structure is - optimised for searchability and filtering, with human-understandable log - keys. - - Args: - event (dict): The Twisted logging event we are flattening. - metadata (dict): Additional data to include with each log message. This - can be information like the server name. Since the target log - consumer does not know who we are other than by host IP, this - allows us to forward through static information. - include_time (bool): Should we include the `time` key? If False, the - event time is stripped from the event. - """ - new_event = {} - - # If it's a failure, make the new event's log_failure be the traceback text. - if "log_failure" in event: - new_event["log_failure"] = event["log_failure"].getTraceback() - - # If it's a warning, copy over a string representation of the warning. - if "warning" in event: - new_event["warning"] = str(event["warning"]) - - # Stdlib logging events have "log_text" as their human-readable portion, - # Twisted ones have "log_format". For now, include the log_format, so that - # context only given in the log format (e.g. what is being logged) is - # available. - if "log_text" in event: - new_event["log"] = event["log_text"] - else: - new_event["log"] = event["log_format"] - - # We want to include the timestamp when forwarding over the network, but - # exclude it when we are writing to stdout. This is because the log ingester - # (e.g. logstash, fluentd) can add its own timestamp. - if include_time: - new_event["time"] = round(event["log_time"], 2) - - # Convert the log level to a textual representation. - new_event["level"] = event["log_level"].name.upper() - - # Ignore these keys, and do not transfer them over to the new log object. - # They are either useless (isError), transferred manually above (log_time, - # log_level, etc), or contain Python objects which are not useful for output - # (log_logger, log_source). - keys_to_delete = [ - "isError", - "log_failure", - "log_format", - "log_level", - "log_logger", - "log_source", - "log_system", - "log_time", - "log_text", - "observer", - "warning", - ] - - # If it's from the Twisted legacy logger (twisted.python.log), it adds some - # more keys we want to purge. - if event.get("log_namespace") == "log_legacy": - keys_to_delete.extend(["message", "system", "time"]) - - # Rather than modify the dictionary in place, construct a new one with only - # the content we want. The original event should be considered 'frozen'. 
- for key in event.keys(): - - if key in keys_to_delete: - continue - - if isinstance(event[key], (str, int, bool, float)) or event[key] is None: - # If it's a plain type, include it as is. - new_event[key] = event[key] - else: - # If it's not one of those basic types, write out a string - # representation. This should probably be a warning in development, - # so that we are sure we are only outputting useful data. - new_event[key] = str(event[key]) - - # Add the metadata information to the event (e.g. the server_name). - new_event.update(metadata) - - return new_event - - -def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver: - """ - A log observer that formats events to a flattened JSON representation. - - Args: - outFile: The file object to write to. - metadata: Metadata to be added to each log object. - """ - - def formatEvent(_event: dict) -> str: - flattened = flatten_event(_event, metadata) - return _encoder.encode(flattened) + "\n" - - return FileLogObserver(outFile, formatEvent) - - -def TerseJSONToTCPLogObserver( - hs, host: str, port: int, metadata: dict, maximum_buffer: int -) -> FileLogObserver: - """ - A log observer that formats events to a flattened JSON representation. - - Args: - hs (HomeServer): The homeserver that is being logged for. - host: The host of the logging target. - port: The logging target's port. - metadata: Metadata to be added to each log object. - maximum_buffer: The maximum buffer size. - """ - - def formatEvent(_event: dict) -> str: - flattened = flatten_event(_event, metadata, include_time=True) - return _encoder.encode(flattened) + "\n" - - return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer) +# The properties of a standard LogRecord. +_LOG_RECORD_ATTRIBUTES = { + "args", + "asctime", + "created", + "exc_info", + # exc_text isn't a public attribute, but is used to cache the result of formatException. + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "message", + "module", + "msecs", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "thread", + "threadName", +} + + +class JsonFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + event = { + "log": record.getMessage(), + "namespace": record.name, + "level": record.levelname, + } + + return self._format(record, event) + + def _format(self, record: logging.LogRecord, event: dict) -> str: + # Add any extra attributes to the event. + for key, value in record.__dict__.items(): + if key not in _LOG_RECORD_ATTRIBUTES: + event[key] = value + + return _encoder.encode(event) + + +class TerseJsonFormatter(JsonFormatter): + def format(self, record: logging.LogRecord) -> str: + event = { + "log": record.getMessage(), + "namespace": record.name, + "level": record.levelname, + "time": round(record.created, 2), + } + + return self._format(record, event) diff --git a/synapse/logging/filter.py b/synapse/logging/filter.py new file mode 100644 index 0000000000..1baf8dd679 --- /dev/null +++ b/synapse/logging/filter.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from typing_extensions import Literal + + +class MetadataFilter(logging.Filter): + """Logging filter that adds constant values to each record. + + Args: + metadata: Key-value pairs to add to each record. + """ + + def __init__(self, metadata: dict): + self._metadata = metadata + + def filter(self, record: logging.LogRecord) -> Literal[True]: + for key, value in self._metadata.items(): + setattr(record, key, value) + return True diff --git a/synmark/__init__.py b/synmark/__init__.py index 09bc7e7927..3d4ec3e184 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -21,45 +21,6 @@ except ImportError: from twisted.internet.pollreactor import PollReactor as Reactor from twisted.internet.main import installReactor -from synapse.config.homeserver import HomeServerConfig -from synapse.util import Clock - -from tests.utils import default_config, setup_test_homeserver - - -async def make_homeserver(reactor, config=None): - """ - Make a Homeserver suitable for running benchmarks against. - - Args: - reactor: A Twisted reactor to run under. - config: A HomeServerConfig to use, or None. - """ - cleanup_tasks = [] - clock = Clock(reactor) - - if not config: - config = default_config("test") - - config_obj = HomeServerConfig() - config_obj.parse_config_dict(config, "", "") - - hs = setup_test_homeserver( - cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock - ) - stor = hs.get_datastore() - - # Run the database background updates. - if hasattr(stor.db_pool.updates, "do_next_background_update"): - while not await stor.db_pool.updates.has_completed_background_updates(): - await stor.db_pool.updates.do_next_background_update(1) - - def cleanup(): - for i in cleanup_tasks: - i() - - return hs, clock.sleep, cleanup - def make_reactor(): """ diff --git a/synmark/__main__.py b/synmark/__main__.py index 17df9ddeb7..de13c1a909 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -12,20 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import sys from argparse import REMAINDER from contextlib import redirect_stderr from io import StringIO import pyperf -from synmark import make_reactor -from synmark.suites import SUITES from twisted.internet.defer import Deferred, ensureDeferred from twisted.logger import globalLogBeginner, textFileLogObserver from twisted.python.failure import Failure +from synmark import make_reactor +from synmark.suites import SUITES + from tests.utils import setupdb diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index d8e4c7d58f..c9d9cf761e 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import warnings from io import StringIO from mock import Mock from pyperf import perf_counter -from synmark import make_homeserver from twisted.internet.defer import Deferred from twisted.internet.protocol import ServerFactory -from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.logger import LogBeginner, LogPublisher from twisted.protocols.basic import LineOnlyReceiver -from synapse.logging._structured import setup_structured_logging +from synapse.config.logger import _setup_stdlib_logging +from synapse.logging import RemoteHandler +from synapse.util import Clock class LineCounter(LineOnlyReceiver): @@ -62,7 +64,15 @@ async def main(reactor, loops): logger_factory.on_done = Deferred() port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") - hs, wait, cleanup = await make_homeserver(reactor) + # A fake homeserver config. + class Config: + server_name = "synmark-" + str(loops) + no_redirect_stdio = True + + hs_config = Config() + + # To be able to sleep. + clock = Clock(reactor) errors = StringIO() publisher = LogPublisher() @@ -72,47 +82,49 @@ async def main(reactor, loops): ) log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { + "version": 1, + "loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}}, + "formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}}, + "handlers": { "tersejson": { - "type": "network_json_terse", + "class": "synapse.logging.RemoteHandler", "host": "127.0.0.1", "port": port.getHost().port, "maximum_buffer": 100, + "_reactor": reactor, } }, } - logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) - logging_system = setup_structured_logging( - hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + logger = logging.getLogger("synapse.logging.test_terse_json") + _setup_stdlib_logging( + hs_config, log_config, logBeginner=beginner, ) # Wait for it to connect... - await logging_system._observers[0]._service.whenConnected() + for handler in logging.getLogger("synapse").handlers: + if isinstance(handler, RemoteHandler): + break + else: + raise RuntimeError("Improperly configured: no RemoteHandler found.") + + await handler._service.whenConnected() start = perf_counter() # Send a bunch of useful messages for i in range(0, loops): - logger.info("test message %s" % (i,)) - - if ( - len(logging_system._observers[0]._buffer) - == logging_system._observers[0].maximum_buffer - ): - while ( - len(logging_system._observers[0]._buffer) - > logging_system._observers[0].maximum_buffer / 2 - ): - await wait(0.01) + logger.info("test message %s", i) + + if len(handler._buffer) == handler.maximum_buffer: + while len(handler._buffer) > handler.maximum_buffer / 2: + await clock.sleep(0.01) await logger_factory.on_done end = perf_counter() - start - logging_system.stop() + handler.close() port.stopListening() - cleanup() return end diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py index e69de29bb2..a58d51441c 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
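[Editor's note] The rewritten benchmark paces itself against `RemoteHandler`'s bounded buffer: once the buffer fills, it sleeps until the handler has drained to half capacity before sending more. That pattern, pulled out into a hedged helper (assuming a handler that exposes `_buffer` and `maximum_buffer` as `RemoteHandler` does above):

```
from synapse.util import Clock


async def send_paced(logger, handler, clock: Clock, count: int) -> None:
    for i in range(count):
        logger.info("test message %s", i)

        # Back off while the handler's buffer is saturated, exactly as the
        # benchmark loop above does.
        if len(handler._buffer) == handler.maximum_buffer:
            while len(handler._buffer) > handler.maximum_buffer / 2:
                await clock.sleep(0.01)
```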
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+
+class LoggerCleanupMixin:
+    def get_logger(self, handler):
+        """
+        Attach a handler to a logger and add clean-ups to revert this.
+        """
+        # Create a logger and add the handler to it.
+        logger = logging.getLogger(__name__)
+        logger.addHandler(handler)
+
+        # Ensure the logger actually logs something.
+        logger.setLevel(logging.INFO)
+
+        # Ensure the logger gets cleaned-up appropriately.
+        self.addCleanup(logger.removeHandler, handler)
+        self.addCleanup(logger.setLevel, logging.NOTSET)
+
+        return logger
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
new file mode 100644
index 0000000000..58ee1f2f3c
--- /dev/null
+++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import AccumulatingProtocol
+
+from synapse.logging import RemoteHandler
+
+from tests.logging import LoggerCleanupMixin
+from tests.server import FakeTransport, get_clock
+from tests.unittest import TestCase
+
+
+def connect_logging_client(reactor, client_id):
+    # This is essentially tests.server.connect_client, but disabling autoflush on
+    # the client transport. This is necessary to avoid an infinite loop due to
+    # sending of data via the logging transport causing additional logs to be
+    # written.
+    factory = reactor.tcpClients.pop(client_id)[2]
+    client = factory.buildProtocol(None)
+    server = AccumulatingProtocol()
+    server.makeConnection(FakeTransport(client, reactor))
+    client.makeConnection(FakeTransport(server, reactor, autoflush=False))
+
+    return client, server
+
+
+class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
+    def setUp(self):
+        self.reactor, _ = get_clock()
+
+    def test_log_output(self):
+        """
+        The remote handler delivers logs over TCP.
+        """
+        handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
+        logger = self.get_logger(handler)
+
+        logger.info("Hello there, %s!", "wally")
+
+        # Trigger the connection
+        client, server = connect_logging_client(self.reactor, 0)
+
+        # Trigger data being sent
+        client.transport.flush()
+
+        # One log message, with a single trailing newline
+        logs = server.data.decode("utf8").splitlines()
+        self.assertEqual(len(logs), 1)
+        self.assertEqual(server.data.count(b"\n"), 1)
+
+        # Ensure the data passed through properly.
+        self.assertEqual(logs[0], "Hello there, wally!")
+
+    def test_log_backpressure_debug(self):
+        """
+        When backpressure is hit, DEBUG logs will be shed.
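[Editor's note] `LoggerCleanupMixin.get_logger()` above keeps handlers and level changes from leaking between test cases. A hedged sketch of using it with an ordinary in-memory handler (the test class here is illustrative):

```
import logging
from io import StringIO

from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase


class ExampleLoggingTestCase(LoggerCleanupMixin, TestCase):
    def test_message_is_logged(self):
        output = StringIO()
        # get_logger() registers cleanups that detach the handler and reset
        # the level once the test finishes.
        logger = self.get_logger(logging.StreamHandler(output))

        logger.info("Hello there, %s!", "wally")
        self.assertIn("Hello there, wally!", output.getvalue())
```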
+        """
+        handler = RemoteHandler(
+            "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+        )
+        logger = self.get_logger(handler)
+
+        # Send some debug messages
+        for i in range(0, 3):
+            logger.debug("debug %s" % (i,))
+
+        # Send a bunch of useful messages
+        for i in range(0, 7):
+            logger.info("info %s" % (i,))
+
+        # The last debug message pushes it past the maximum buffer
+        logger.debug("too much debug")
+
+        # Allow the reconnection
+        client, server = connect_logging_client(self.reactor, 0)
+        client.transport.flush()
+
+        # Only the 7 infos made it through, the debugs were elided
+        logs = server.data.splitlines()
+        self.assertEqual(len(logs), 7)
+        self.assertNotIn(b"debug", server.data)
+
+    def test_log_backpressure_info(self):
+        """
+        When backpressure is hit, DEBUG and INFO logs will be shed.
+        """
+        handler = RemoteHandler(
+            "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+        )
+        logger = self.get_logger(handler)
+
+        # Send some debug messages
+        for i in range(0, 3):
+            logger.debug("debug %s" % (i,))
+
+        # Send a bunch of useful messages
+        for i in range(0, 10):
+            logger.warning("warn %s" % (i,))
+
+        # Send a bunch of info messages
+        for i in range(0, 3):
+            logger.info("info %s" % (i,))
+
+        # The last debug message pushes it past the maximum buffer
+        logger.debug("too much debug")
+
+        # Allow the reconnection
+        client, server = connect_logging_client(self.reactor, 0)
+        client.transport.flush()
+
+        # The 10 warnings made it through, the debugs and infos were elided
+        logs = server.data.splitlines()
+        self.assertEqual(len(logs), 10)
+        self.assertNotIn(b"debug", server.data)
+        self.assertNotIn(b"info", server.data)
+
+    def test_log_backpressure_cut_middle(self):
+        """
+        When backpressure is hit, and no more DEBUG and INFO messages can be culled,
+        it will cut the middle messages out.
+        """
+        handler = RemoteHandler(
+            "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+        )
+        logger = self.get_logger(handler)
+
+        # Send a bunch of useful messages
+        for i in range(0, 20):
+            logger.warning("warn %s" % (i,))
+
+        # Allow the reconnection
+        client, server = connect_logging_client(self.reactor, 0)
+        client.transport.flush()
+
+        # The first five and last five warnings made it through, the debugs and
+        # infos were elided
+        logs = server.data.decode("utf8").splitlines()
+        self.assertEqual(
+            ["warn %s" % (i,) for i in range(5)]
+            + ["warn %s" % (i,) for i in range(15, 20)],
+            logs,
+        )
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
deleted file mode 100644
index d36f5f426c..0000000000
--- a/tests/logging/test_structured.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
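[Editor's note] Taken together, the three backpressure tests above pin down the shedding policy the handler is expected to follow: drop DEBUG records first, then INFO, and only then cut records out of the middle so the oldest and newest survive. A rough, hedged sketch of that policy (not `RemoteHandler`'s actual implementation):

```
import logging
from collections import deque


def shed(buffer: deque, maximum: int) -> deque:
    # Discard the least useful records first.
    for level in (logging.DEBUG, logging.INFO):
        if len(buffer) <= maximum:
            return buffer
        buffer = deque(r for r in buffer if r.levelno > level)

    if len(buffer) > maximum:
        # Nothing left to cull by level: keep the first and last halves and
        # elide the middle, as test_log_backpressure_cut_middle expects.
        records = list(buffer)
        half = maximum // 2
        buffer = deque(records[:half] + records[-half:])
    return buffer
```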
- -import logging -import os -import os.path -import shutil -import sys -import textwrap - -from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile - -from synapse.config.logger import setup_logging -from synapse.logging._structured import setup_structured_logging -from synapse.logging.context import LoggingContext - -from tests.unittest import DEBUG, HomeserverTestCase - - -class FakeBeginner: - def beginLoggingTo(self, observers, **kwargs): - self.observers = observers - - -class StructuredLoggingTestBase: - """ - Test base that registers a cleanup handler to reset the stdlib log handler - to 'unset'. - """ - - def prepare(self, reactor, clock, hs): - def _cleanup(): - logging.getLogger("synapse").setLevel(logging.NOTSET) - - self.addCleanup(_cleanup) - - -class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase): - """ - Tests for Synapse's structured logging support. - """ - - def test_output_to_json_round_trip(self): - """ - Synapse logs can be outputted to JSON and then read back again. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json")) - - log_config = { - "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}} - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(json_log_file, "r") as f: - logged_events = list(eventsFromJSONLogFile(f)) - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertEqual( - eventAsText(logged_events[0], includeTimestamp=False), - "[tests.logging.test_structured#info] Hello there, wally!", - ) - - def test_output_to_text(self): - """ - Synapse logs can be outputted to text. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - log_file = os.path.abspath(os.path.join(temp_dir, "out.log")) - - log_config = {"drains": {"file": {"type": "file", "location": log_file}}} - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(log_file, "r") as f: - logged_events = f.read().strip().split("\n") - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertTrue( - logged_events[0].endswith( - " - tests.logging.test_structured - INFO - None - Hello there, wally!" - ) - ) - - def test_collects_logcontext(self): - """ - Test that log outputs have the attached logging context. 
- """ - log_config = {"drains": {}} - - # Begin the logger with our config - beginner = FakeBeginner() - publisher = setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logs = [] - - publisher.addObserver(logs.append) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - self.assertEqual(len(logs), 1) - self.assertEqual(logs[0]["request"], "somereq") - - -class StructuredLoggingConfigurationFileTestCase( - StructuredLoggingTestBase, HomeserverTestCase -): - def make_homeserver(self, reactor, clock): - - tempdir = self.mktemp() - os.mkdir(tempdir) - log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml")) - self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log")) - - config = self.default_config() - config["log_config"] = log_config_file - - with open(log_config_file, "w") as f: - f.write( - textwrap.dedent( - """\ - structured: true - - drains: - file: - type: file_json - location: %s - """ - % (self.homeserver_log,) - ) - ) - - self.addCleanup(self._sys_cleanup) - - return self.setup_test_homeserver(config=config) - - def _sys_cleanup(self): - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # Do not remove! We need the logging system to be set other than WARNING. - @DEBUG - def test_log_output(self): - """ - When a structured logging config is given, Synapse will use it. - """ - beginner = FakeBeginner() - publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner) - - # Make a logger and send an event - logger = Logger(namespace="tests.logging.test_structured", observer=publisher) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - with open(self.homeserver_log, "r") as f: - logged_events = [ - eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f) - ] - - logs = "\n".join(logged_events) - self.assertTrue("***** STARTING SERVER *****" in logs) - self.assertTrue("Hello there, steve!" in logs) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index fd128b88e0..73f469b802 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -14,57 +14,33 @@ # limitations under the License. import json -from collections import Counter +import logging +from io import StringIO -from twisted.logger import Logger +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter -from synapse.logging._structured import setup_structured_logging +from tests.logging import LoggerCleanupMixin +from tests.unittest import TestCase -from tests.server import connect_client -from tests.unittest import HomeserverTestCase -from .test_structured import FakeBeginner, StructuredLoggingTestBase - - -class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): - def test_log_output(self): +class TerseJsonTestCase(LoggerCleanupMixin, TestCase): + def test_terse_json_output(self): """ - The Terse JSON outputter delivers simplified structured logs over TCP. + The Terse JSON formatter converts log messages to JSON. 
""" - log_config = { - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - } - } - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logger = Logger( - namespace="tests.logging.test_terse_json", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Trigger the connection - self.pump() + output = StringIO() - _, server = connect_client(self.reactor, 0) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Trigger data being sent - self.pump() + logger.info("Hello there, %s!", "wally") - # One log message, with a single trailing newline - logs = server.data.decode("utf8").splitlines() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() self.assertEqual(len(logs), 1) - self.assertEqual(server.data.count(b"\n"), 1) - + self.assertEqual(data.count("\n"), 1) log = json.loads(logs[0]) # The terse logger should give us these keys. @@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): "log", "time", "level", - "log_namespace", - "request", - "scope", - "server_name", - "name", + "namespace", ] self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") - # It contains the data we expect. - self.assertEqual(log["name"], "wally") - - def test_log_backpressure_debug(self): + def test_extra_data(self): """ - When backpressure is hit, DEBUG logs will be shed. + Additional information can be included in the structured logging. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + output = StringIO() - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 7): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - _, server = connect_client(self.reactor, 0) - self.pump() - - # Only the 7 infos made it through, the debugs were elided - logs = server.data.splitlines() - self.assertEqual(len(logs), 7) - - def test_log_backpressure_info(self): - """ - When backpressure is hit, DEBUG and INFO logs will be shed. 
- """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + logger.info( + "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True} ) - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) - - # Send a bunch of useful messages - for i in range(0, 10): - logger.warn("test warn %s" % (i,)) - - # Send a bunch of info messages - for i in range(0, 3): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The 10 warnings made it through, the debugs and infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "time", + "level", + "namespace", + # The additional keys given via extra. + "foo", + "int", + "bool", + ] + self.assertCountEqual(log.keys(), expected_log_keys) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + # Check the values of the extra fields. + self.assertEqual(log["foo"], "bar") + self.assertEqual(log["int"], 3) + self.assertIs(log["bool"], True) - def test_log_backpressure_cut_middle(self): + def test_json_output(self): """ - When backpressure is hit, and no more DEBUG and INFOs cannot be culled, - it will cut the middle messages out. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + output = StringIO() - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + handler = logging.StreamHandler(output) + handler.setFormatter(JsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 20): - logger.warn("test warn", num=i) + logger.info("Hello there, %s!", "wally") - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. 
+ data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The first five and last five warnings made it through, the debugs and - # infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) - self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs]) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "level", + "namespace", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") diff --git a/tests/server.py b/tests/server.py index ea9c22bc51..b97003fa5a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -571,12 +571,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol reactor factory: The connecting factory to build. """ - factory = reactor.tcpClients[client_id][2] + factory = reactor.tcpClients.pop(client_id)[2] client = factory.buildProtocol(None) server = AccumulatingProtocol() server.makeConnection(FakeTransport(client, reactor)) client.makeConnection(FakeTransport(server, reactor)) - reactor.tcpClients.pop(client_id) - return client, server -- cgit 1.5.1 From 46f4be94b410776ef3f922af2f437eb17631d2fa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Oct 2020 10:55:24 +0000 Subject: Fix race for concurrent downloads of remote media. (#8682) Fixes #6755 --- changelog.d/8682.bugfix | 1 + synapse/rest/media/v1/media_repository.py | 165 +++++++----- synapse/rest/media/v1/media_storage.py | 30 ++- synapse/storage/databases/main/media_repository.py | 27 ++ tests/replication/test_multi_media_repo.py | 277 +++++++++++++++++++++ tests/server.py | 2 +- 6 files changed, 431 insertions(+), 71 deletions(-) create mode 100644 changelog.d/8682.bugfix create mode 100644 tests/replication/test_multi_media_repo.py (limited to 'tests/server.py') diff --git a/changelog.d/8682.bugfix b/changelog.d/8682.bugfix new file mode 100644 index 0000000000..e61276aa05 --- /dev/null +++ b/changelog.d/8682.bugfix @@ -0,0 +1 @@ +Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 5cce7237a0..9cac74ebd8 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -305,15 +305,12 @@ class MediaRepository: # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new # one. - if media_info: - file_id = media_info["filesystem_id"] - else: - file_id = random_string(24) - - file_info = FileInfo(server_name, file_id) # If we have an entry in the DB, try and look for it if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + if media_info["quarantined_by"]: logger.info("Media is quarantined") raise NotFoundError() @@ -324,14 +321,34 @@ class MediaRepository: # Failed to find the file anywhere, lets download it. 
- media_info = await self._download_remote_file(server_name, media_id, file_id) + try: + media_info = await self._download_remote_file(server_name, media_id,) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file( - self, server_name: str, media_id: str, file_id: str - ) -> dict: + async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -346,6 +363,8 @@ class MediaRepository: The media info of the file. """ + file_id = random_string(24) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): @@ -401,22 +420,32 @@ class MediaRepository: await finish() - media_type = headers[b"Content-Type"][0].decode("ascii") - upload_name = get_filename_from_headers(headers) - time_now_ms = self.clock.time_msec() + media_type = headers[b"Content-Type"][0].decode("ascii") + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. 
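[Editor's note] The recovery path above boils down to: attempt the download, and on failure check whether a concurrent worker already populated the cache before propagating the error. Distilled into a hedged, self-contained sketch (the callables stand in for the real download and cache-lookup coroutines):

```
from typing import Awaitable, Callable, Optional


async def get_or_download(
    download: Callable[[], Awaitable[dict]],
    check_cache: Callable[[], Awaitable[Optional[dict]]],
) -> dict:
    try:
        return await download()
    except Exception:
        # Another media repository worker may have won the race; prefer its
        # copy of the media over surfacing our own failure.
        media_info = await check_cache()
        if media_info is None:
            raise
        return media_info
```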
+ await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) logger.info("Stored remote media in file %r", fname) - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - media_info = { "media_type": media_type, "media_length": length, @@ -425,8 +454,6 @@ class MediaRepository: "filesystem_id": file_id, } - await self._generate_thumbnails(server_name, media_id, file_id, media_type) - return media_info def _get_thumbnail_requirements(self, media_type): @@ -692,42 +719,60 @@ class MediaRepository: if not t_byte_source: continue - try: - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, - url_cache=url_cache, - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - t_len = os.path.getsize(output_path) + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) - # Write to database - if server_name: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. 
+ try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = await self.store.get_remote_media_thumbnail( + server_name, media_id, t_width, t_height, t_type, + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a9586fb0b7..268e0c8f50 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -52,6 +52,7 @@ class MediaStorage: storage_providers: Sequence["StorageProviderWrapper"], ): self.hs = hs + self.reactor = hs.get_reactor() self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers @@ -70,13 +71,16 @@ class MediaStorage: with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - await defer_to_thread( - self.hs.get_reactor(), _write_file_synchronously, source, f - ) + await self.write_to_file(source, f) await finish_cb() return fname + async def write_to_file(self, source: IO, output: IO): + """Asynchronously write the `source` to `output`. + """ + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + @contextlib.contextmanager def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as @@ -112,14 +116,20 @@ class MediaStorage: finished_called = [False] - async def finish(): - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - try: with open(fname, "wb") as f: + + async def finish(): + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + yield f, fname, finish except Exception: try: @@ -210,7 +220,7 @@ class MediaStorage: if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor() + open(local_path, "wb"), self.reactor ) await res.write_to_consumer(consumer) await consumer.wait() diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index daf57675d8..4b2f224718 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) + async def get_remote_media_thumbnail( + self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + ) -> Optional[Dict[str, Any]]: + """Fetch the thumbnail info of given width, height and type. 
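[Editor's note] The thumbnail path uses the same defensive write as the media row: insert, and if a unique-constraint violation fires, check whether an equivalent row already exists before failing. A hedged distillation of that pattern (the callables stand in for the real storage methods):

```
from typing import Awaitable, Callable, Optional


async def store_idempotently(
    insert: Callable[[], Awaitable[None]],
    fetch_existing: Callable[[], Awaitable[Optional[dict]]],
) -> None:
    try:
        await insert()
    except Exception:
        # A concurrent worker may have inserted the same row; only propagate
        # the error if the row really is missing.
        if not await fetch_existing():
            raise
```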
+ """ + + return await self.db_pool.simple_select_one( + table="remote_media_cache_thumbnails", + keyvalues={ + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": t_width, + "thumbnail_height": t_height, + "thumbnail_type": t_type, + }, + retcols=( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + allow_none=True, + desc="get_remote_media_thumbnail", + ) + async def store_remote_media_thumbnail( self, origin, diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py new file mode 100644 index 0000000000..77c261dbf7 --- /dev/null +++ b/tests/replication/test_multi_media_repo.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from binascii import unhexlify +from typing import Tuple + +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel +from twisted.web.server import Request + +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.server import HomeServer + +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, FakeTransport + +logger = logging.getLogger(__name__) + +test_server_connection_factory = None + + +class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "127.0.0.2" + + def default_config(self): + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. 
+ """ + + request, channel = self.make_request( + "GET", + "/{}/{}".format(target, media_id), + shorthand=False, + access_token=self.access_token, + ) + request.render(hs.get_media_repository_resource().children[b"download"]) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_tls_protocol = _build_test_server(get_connection_factory()) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self): + """Test basic fetching of remote media from a single worker. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request.write(b"Hello!") + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"Hello!") + + def test_download_simple_file_race(self): + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request1.write(b"Hello!") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"Hello!") + + # Now respond to the second with the same content. 
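[Editor's note] The client/server plumbing in `_get_media_req` above is the usual in-memory trick: two `FakeTransport`s, each delivering one protocol's writes to the other protocol. Reduced to a hedged helper (the factories and reactor are assumed to come from the surrounding test):

```
from tests.server import FakeTransport


def wire_together(client_factory, server_factory, reactor):
    """Connect two freshly-built protocols back-to-back in memory (sketch)."""
    client = client_factory.buildProtocol(None)
    server = server_factory.buildProtocol(None)

    # Each transport feeds its writes to the opposite protocol.
    client.makeConnection(FakeTransport(server, reactor, client))
    server.makeConnection(FakeTransport(client, reactor, server))
    return client, server
```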
+ request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request2.write(b"Hello!") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"Hello!") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self): + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + png_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request1.write(png_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], png_data) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request2.write(png_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], png_data) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[b"DNS:example.com"] + ) + return test_server_connection_factory + + +def _build_test_server(connection_creator): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. 
+ server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=server_factory + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/server.py b/tests/server.py index b97003fa5a..3dd2cfc072 100644 --- a/tests/server.py +++ b/tests/server.py @@ -46,7 +46,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() - result = attr.ib(default=attr.Factory(dict)) + result = attr.ib(type=dict, default=attr.Factory(dict)) _producer = None @property -- cgit 1.5.1 From 9debe657a39a234d574e949ae8faf3f5ed027c09 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 13 Nov 2020 22:39:09 +0000 Subject: pass a Site into make_request --- tests/rest/client/v1/utils.py | 31 +++++++++++++++++++++++++------ tests/server.py | 16 +++++++++++++++- tests/test_server.py | 40 +++++++++++++++++++++++++++------------- tests/unittest.py | 1 + 4 files changed, 68 insertions(+), 20 deletions(-) (limited to 'tests/server.py') diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index dc789fbdaa..60e4b9b846 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -27,7 +27,7 @@ from twisted.web.server import Site from synapse.api.constants import Membership -from tests.server import make_request, render +from tests.server import FakeSite, make_request, render @attr.s @@ -53,7 +53,11 @@ class RestHelper: path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") + self.hs.get_reactor(), + self.site, + "POST", + path, + json.dumps(content).encode("utf8"), ) render(request, self.site.resource, self.hs.get_reactor()) @@ -126,7 +130,11 @@ class RestHelper: data.update(extra_data) request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(data).encode("utf8"), ) render(request, self.site.resource, self.hs.get_reactor()) @@ -159,7 +167,11 @@ class RestHelper: path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(content).encode("utf8"), ) render(request, self.site.resource, self.hs.get_reactor()) @@ -211,7 +223,9 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - request, channel = make_request(self.hs.get_reactor(), method, path, content) + request, channel = make_request( + self.hs.get_reactor(), self.site, method, path, content + ) render(request, self.site.resource, self.hs.get_reactor()) @@ -297,7 +311,12 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) request, channel = make_request( - self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok + self.hs.get_reactor(), + FakeSite(resource), + "POST", + path, + content=image_data, + access_token=tok, ) request.requestHeaders.addRawHeader( b"Content-Length", str(image_length).encode("UTF-8") diff --git a/tests/server.py b/tests/server.py index 3dd2cfc072..b9ccde4962 100644 --- a/tests/server.py +++ b/tests/server.py @@ -21,6 +21,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import 
AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.resource import IResource from twisted.web.server import Site from synapse.http.site import SynapseRequest @@ -128,9 +129,21 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") + def __init__(self, resource: IResource): + """ + + Args: + resource: the resource to be used for rendering all requests + """ + self._resource = resource + + def getResourceFor(self, request): + return self._resource + def make_request( reactor, + site: Site, method, path, content=b"", @@ -145,6 +158,8 @@ def make_request( content, and return the Request and the Channel underneath. Args: + site: The twisted Site to associate with the Channel + method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). @@ -181,7 +196,6 @@ def make_request( if isinstance(content, str): content = content.encode("utf8") - site = FakeSite() channel = FakeChannel(site, reactor) req = request(channel) diff --git a/tests/test_server.py b/tests/test_server.py index 655c918a15..300d13ac95 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -26,6 +26,7 @@ from synapse.util import Clock from tests import unittest from tests.server import ( + FakeSite, ThreadedMemoryReactorClock, make_request, render, @@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase): ) request, channel = make_request( - self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" + self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) render(request, res, self.reactor) @@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") + request, channel = make_request( + self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" + ) render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"500") @@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") + request, channel = make_request( + self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" + ) render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"500") @@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") + request, channel = make_request( + self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" + ) render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"403") @@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") + request, channel = make_request( + self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar" + ) render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"400") @@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. 
- request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") + request, channel = make_request( + self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo" + ) render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"200") @@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase): def _make_request(self, method, path): """Create a request from the method/path and return a channel with the response.""" - request, channel = make_request(self.reactor, method, path, shorthand=False) - request.prepath = [] # This doesn't get set properly by make_request. - # Create a site and query for the resource. site = SynapseSite( "test", @@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase): self.resource, "1.0", ) + + request, channel = make_request( + self.reactor, site, method, path, shorthand=False + ) + request.prepath = [] # This doesn't get set properly by make_request. + request.site = site resource = site.getResourceFor(request) @@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") + request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"200") @@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") + request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"301") @@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") + request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"304") @@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"HEAD", b"/path") + request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"200") diff --git a/tests/unittest.py b/tests/unittest.py index 0a24c2f6b2..8c7979a7c0 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -434,6 +434,7 @@ class HomeserverTestCase(TestCase): return make_request( self.reactor, + self.site, method, path, content, -- cgit 1.5.1 From 70c0d47989b7794766ea957369c77d99664429c5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 13 Nov 2020 23:48:25 +0000 Subject: fix dict handling for make_request() --- tests/server.py | 2 ++ tests/unittest.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'tests/server.py') diff --git a/tests/server.py b/tests/server.py index b9ccde4962..a74fb3fc67 100644 --- a/tests/server.py +++ b/tests/server.py @@ -193,6 +193,8 @@ def make_request( if not path.startswith(b"/"): path = b"/" + path + if isinstance(content, dict): + content = json.dumps(content).encode("utf8") if isinstance(content, str): content = content.encode("utf8") diff --git a/tests/unittest.py b/tests/unittest.py index 8c7979a7c0..3e656b7b12 100644 --- 
a/tests/unittest.py +++ b/tests/unittest.py @@ -429,9 +429,6 @@ class HomeserverTestCase(TestCase): Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ - if isinstance(content, dict): - content = json.dumps(content).encode("utf8") - return make_request( self.reactor, self.site, -- cgit 1.5.1 From ebc405446e6615d6187a2e29cb33f27dd5bd0841 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Nov 2020 14:45:22 +0000 Subject: Add a `custom_headers` param to `make_request` (#8760) Some tests want to set some custom HTTP request headers, so provide a way to do that before calling requestReceived(). --- changelog.d/8760.misc | 1 + tests/rest/client/v1/utils.py | 10 ++++++---- tests/server.py | 11 ++++++++++- tests/storage/test_client_ips.py | 13 +++++++------ 4 files changed, 24 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8760.misc (limited to 'tests/server.py') diff --git a/changelog.d/8760.misc b/changelog.d/8760.misc new file mode 100644 index 0000000000..54502e9b90 --- /dev/null +++ b/changelog.d/8760.misc @@ -0,0 +1 @@ +Refactor test utilities for injecting HTTP requests. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index afaf9f7b85..1b2d0497a6 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -296,10 +296,12 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) request, channel = make_request( - self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok - ) - request.requestHeaders.addRawHeader( - b"Content-Length", str(image_length).encode("UTF-8") + self.hs.get_reactor(), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], ) request.render(resource) self.hs.get_reactor().pump([100]) diff --git a/tests/server.py b/tests/server.py index 3dd2cfc072..ef03109a6c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable +from typing import Callable, Iterable, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -139,6 +139,9 @@ def make_request( shorthand=True, federation_auth_origin=None, content_is_form=False, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): """ Make a web request using the given method and path, feed it the @@ -157,6 +160,8 @@ def make_request( content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. 
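[Editor's note] The new `custom_headers` parameter lets a test attach arbitrary request headers before `requestReceived()` runs, rather than mutating `requestHeaders` after the fact. A hedged sketch inside a `HomeserverTestCase`, mirroring the client-IP test below:

```
from tests.server import make_request

request, channel = make_request(
    self.reactor,
    "GET",
    "/_matrix/client/r0/admin/users/" + self.user_id,
    access_token=self.access_token,
    # (name, value) pairs, added verbatim as request headers.
    custom_headers=[(b"User-Agent", b"Mozzila pizza")],
)
```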
+ custom_headers: (name, value) pairs to add as request headers + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ @@ -211,6 +216,10 @@ def make_request( # Assume the body is JSON req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + if custom_headers: + for k, v in custom_headers: + req.requestHeaders.addRawHeader(k, v) + req.requestReceived(method, path, b"1.1") return req, channel diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e96ca1c8ca..efca43ec78 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.server import make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -408,17 +409,17 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): # Advance to a known time self.reactor.advance(123456 - self.reactor.seconds()) - request, channel = self.make_request( + headers1 = {b"User-Agent": b"Mozzila pizza"} + headers1.update(headers) + + request, channel = make_request( + self.reactor, "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, + custom_headers=headers1.items(), **make_request_args, ) - request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") - - # Add the optional headers - for h, v in headers.items(): - request.requestHeaders.addRawHeader(h, v) self.render(request) # Advance so the save loop occurs -- cgit 1.5.1 From f125895475aeee9447f3988ecbd8bfd1836545bf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Nov 2020 18:21:47 +0000 Subject: Move `wait_until_result` into `FakeChannel` (#8758) FakeChannel has everything we need, and this more accurately models the real flow. --- changelog.d/8758.misc | 1 + tests/rest/key/v2/test_remote_key_resource.py | 6 ++-- tests/server.py | 42 +++++++++++++-------------- 3 files changed, 24 insertions(+), 25 deletions(-) create mode 100644 changelog.d/8758.misc (limited to 'tests/server.py') diff --git a/changelog.d/8758.misc b/changelog.d/8758.misc new file mode 100644 index 0000000000..54502e9b90 --- /dev/null +++ b/changelog.d/8758.misc @@ -0,0 +1 @@ +Refactor test utilities for injecting HTTP requests. 
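[Editor's note] Moving `wait_until_result` onto `FakeChannel` puts the waiting logic on the object that actually knows when the response is complete. A hedged sketch of the new call pattern, loosely following the remote-key tests (`site` and `reactor` are assumed to come from the surrounding test case, and the key path is illustrative):

```
from io import BytesIO

from synapse.http.site import SynapseRequest

from tests.server import FakeChannel

channel = FakeChannel(site, reactor)
req = SynapseRequest(channel)
req.content = BytesIO(b"")
req.requestReceived(b"GET", b"/_matrix/key/v2/server/key1", b"1.1")

# Pump the fake reactor until the request has finished producing a response.
channel.await_result()
assert channel.code == 200
```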
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 6671cbd32d..fbcf8d5b86 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string from tests import unittest -from tests.server import FakeChannel, wait_until_result +from tests.server import FakeChannel from tests.utils import default_config @@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): % (server_name.encode("utf-8"), key_id.encode("utf-8")), b"1.1", ) - wait_until_result(self.reactor, req) + channel.await_result() self.assertEqual(channel.code, 200) resp = channel.json_body return resp @@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): req.requestReceived( b"POST", path.encode("utf-8"), b"1.1", ) - wait_until_result(self.reactor, req) + channel.await_result() self.assertEqual(channel.code, 200) resp = channel.json_body return resp diff --git a/tests/server.py b/tests/server.py index ef03109a6c..18cb8b2d72 100644 --- a/tests/server.py +++ b/tests/server.py @@ -117,6 +117,25 @@ class FakeChannel: def transport(self): return self + def await_result(self, timeout: int = 100) -> None: + """ + Wait until the request is finished. + """ + self._reactor.run() + x = 0 + + while not self.result.get("done"): + # If there's a producer, tell it to resume producing so we get content + if self._producer: + self._producer.resumeProducing() + + x += 1 + + if x > timeout: + raise TimedOutException("Timed out waiting for request to finish.") + + self._reactor.advance(0.1) + class FakeSite: """ @@ -225,30 +244,9 @@ def make_request( return req, channel -def wait_until_result(clock, request, timeout=100): - """ - Wait until the request is finished. - """ - clock.run() - x = 0 - - while not request.finished: - - # If there's a producer, tell it to resume producing so we get content - if request._channel._producer: - request._channel._producer.resumeProducing() - - x += 1 - - if x > timeout: - raise TimedOutException("Timed out waiting for request to finish.") - - clock.advance(0.1) - - def render(request, resource, clock): request.render(resource) - wait_until_result(clock, request) + request._channel.await_result() @implementer(IReactorPluggableNameResolver) -- cgit 1.5.1 From 129ae841e5aebb34a980dd7d118140d08b0ff81d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Sun, 15 Nov 2020 22:47:54 +0000 Subject: Make `make_request` actually render the request. Remove the stubbing out of `request.process`, so that `requestReceived` also renders the request via the appropriate resource. Replace render() with a stub for now.
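[Editor's note: a rough sketch of what this commit means for callers, with `resource` and `path` as placeholders. Where tests previously built a request and then rendered and pumped by hand, `requestReceived()` now drives the request through the site's resource tree and `make_request` pumps the reactor until the channel reports completion; the `await_result=False` escape hatch introduced in the diffs below serves tests that need the request while it is still in flight.]

```python
# Before this commit: explicit render + pump after building the request.
request, channel = make_request(reactor, FakeSite(resource), "GET", path)
request.render(resource)
reactor.pump([1.0])

# After this commit: the response is already complete on return.
request, channel = make_request(reactor, FakeSite(resource), "GET", path)
assert channel.code == 200

# Tests that intercept the in-flight request opt out of the wait:
request, channel = make_request(
    reactor, FakeSite(resource), "GET", path, await_result=False
)
```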
--- tests/replication/test_multi_media_repo.py | 4 +- tests/rest/admin/test_admin.py | 6 --- tests/rest/admin/test_media.py | 6 --- tests/rest/client/v1/utils.py | 4 +- tests/rest/client/v2_alpha/test_account.py | 4 -- tests/rest/media/v1/test_media_storage.py | 7 ++-- tests/rest/media/v1/test_url_preview.py | 66 ++++++++++++++---------------- tests/server.py | 20 +++++---- tests/unittest.py | 8 ++++ 9 files changed, 57 insertions(+), 68 deletions(-) (limited to 'tests/server.py') diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index a9ac4aeec1..48b574ccbe 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -68,15 +68,15 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): the media which the caller should respond to. """ resource = hs.get_media_repository_resource().children[b"download"] - request, channel = make_request( + _, channel = make_request( self.reactor, FakeSite(resource), "GET", "/{}/{}".format(target, media_id), shorthand=False, access_token=self.access_token, + await_result=False, ) - request.render(resource) self.pump() clients = self.reactor.tcpClients diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9e4b0bca53..961a5732b3 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -231,8 +231,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): shorthand=False, access_token=admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Should be quarantined self.assertEqual( @@ -301,8 +299,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): shorthand=False, access_token=non_admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Should be successful self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -478,8 +474,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): shorthand=False, access_token=non_admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Shouldn't be quarantined self.assertEqual( diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 36e07f1b36..64b7aa53ee 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -133,8 +133,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): shorthand=False, access_token=self.admin_user_tok, ) - request.render(download_resource) - self.pump(1.0) # Should be successful self.assertEqual( @@ -172,8 +170,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): shorthand=False, access_token=self.admin_user_tok, ) - request.render(download_resource) - self.pump(1.0) self.assertEqual( 404, channel.code, @@ -548,8 +544,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): shorthand=False, access_token=self.admin_user_tok, ) - request.render(download_resource) - self.pump(1.0) if expect_success: self.assertEqual( diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 900852f85b..040a92d6f0 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -310,7 +310,7 @@ class RestHelper: """ image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), FakeSite(resource), "POST", @@ -319,8 +319,6 @@ class RestHelper: access_token=tok, custom_headers=[(b"Content-Length", str(image_length))], ) - 
request.render(resource) - self.hs.get_reactor().pump([100]) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 94a627b0a6..b871200909 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -263,8 +263,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): path, shorthand=False, ) - request.render(self.submit_token_resource) - self.pump() self.assertEquals(200, channel.code, channel.result) @@ -288,8 +286,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - request.render(self.submit_token_resource) - self.pump() self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 0fd31a0096..2a3b2a8f27 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -234,8 +234,8 @@ class MediaRepoTests(unittest.HomeserverTestCase): "GET", self.media_id, shorthand=False, + await_result=False, ) - request.render(self.download_resource) self.pump() # We've made one fetch, to example.com, using the media URL, and asking @@ -330,8 +330,8 @@ class MediaRepoTests(unittest.HomeserverTestCase): "GET", self.media_id + params, shorthand=False, + await_result=False, ) - request.render(self.thumbnail_resource) self.pump() headers = { @@ -359,7 +359,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_NOT_FOUND", - "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']" - % method, + "error": "Not found [b'example.com', b'12345']", }, ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index e00ad61231..ccdc8c2ecf 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -140,9 +140,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False, + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -165,8 +167,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False ) - request.render(self.preview_url) - self.pump() # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -183,8 +183,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False ) - request.render(self.preview_url) - self.pump() # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -204,9 +202,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request, channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False, + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -237,9 +237,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request, channel = 
self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -270,9 +272,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request, channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -301,9 +305,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -331,8 +337,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -354,8 +358,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( @@ -373,8 +375,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://192.168.1.1", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -394,8 +394,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://1.1.1.2", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 403) self.assertEqual( @@ -414,9 +412,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] request, channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -451,8 +451,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, @@ -473,8 +471,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. 
self.assertEqual(len(self.reactor.tcpClients), 0) @@ -496,8 +492,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( @@ -515,8 +509,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "OPTIONS", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) @@ -528,9 +520,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Build and make a request to the server request, channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() # Extract Synapse's tcp client @@ -603,8 +597,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): "GET", "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -668,8 +662,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): "GET", "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) diff --git a/tests/server.py b/tests/server.py index 5a1583a3e7..de7cb1d8b3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -171,16 +171,18 @@ def make_request( shorthand=True, federation_auth_origin=None, content_is_form=False, + await_result: bool = True, custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, ): """ - Make a web request using the given method and path, feed it the - content, and return the Request and the Channel underneath. + Make a web request using the given method, path and content, and render it. + + Returns the Request and the Channel underneath. Args: - site: The twisted Site to associate with the Channel + site: The twisted Site to use to render the request method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. @@ -196,6 +198,10 @@ def make_request( custom_headers: (name, value) pairs to add as request headers + await_result: whether to wait for the request to complete rendering. If true, + will pump the reactor until the renderer tells the channel the request + is finished. + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ @@ -225,11 +231,9 @@ def make_request( channel = FakeChannel(site, reactor) req = request(channel) - req.process = lambda: b"" req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END) - req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( @@ -257,12 +261,14 @@ def make_request( req.requestReceived(method, path, b"1.1") + if await_result: + channel.await_result() + return req, channel def render(request, resource, clock): - request.render(resource) - request._channel.await_result() + pass @implementer(IReactorPluggableNameResolver) diff --git a/tests/unittest.py b/tests/unittest.py index e39cb8dec9..9c7eca3b6e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -377,6 +377,7 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, + await_result: bool = True, ) -> Tuple[SynapseRequest, FakeChannel]: ... @@ -391,6 +392,7 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, + await_result: bool = True, ) -> Tuple[T, FakeChannel]: ... @@ -404,6 +406,7 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, + await_result: bool = True, ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the @@ -422,6 +425,10 @@ class HomeserverTestCase(TestCase): content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + await_result: whether to wait for the request to complete rendering. If + true (the default), will pump the test reactor until the renderer + tells the channel the request is finished. + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ @@ -436,6 +443,7 @@ class HomeserverTestCase(TestCase): shorthand, federation_auth_origin, content_is_form, + await_result, ) def render(self, request): -- cgit 1.5.1 From be8fa65d0baddcc0a64954e21d38a854e4ee00d7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Sun, 15 Nov 2020 22:49:21 +0000 Subject: Remove redundant calls to `render()` --- tests/app/test_frontend_proxy.py | 10 ++--- tests/app/test_openid_listener.py | 12 +++--- tests/http/test_additional_resource.py | 4 +- tests/replication/_base.py | 5 +-- tests/replication/test_client_reader_shard.py | 4 -- tests/replication/test_sharded_event_persister.py | 8 ---- tests/rest/client/test_consent.py | 6 +-- tests/rest/client/v1/utils.py | 16 +++----- tests/rest/client/v2_alpha/test_sync.py | 6 +-- tests/server.py | 5 --- tests/storage/test_client_ips.py | 3 +- tests/test_server.py | 49 +++++------------------ tests/unittest.py | 10 +---- 13 files changed, 32 insertions(+), 106 deletions(-) (limited to 'tests/server.py') diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 0bac7995e8..40abe9d72d 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -15,7 +15,7 @@ from synapse.app.generic_worker import GenericWorkerServer -from tests.server import make_request, render +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -56,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = make_request(self.reactor, site, "PUT", "presence/a/status") - render(request, resource, self.reactor) + _, channel =
make_request(self.reactor, site, "PUT", "presence/a/status") # 400 + unrecognised, because nothing is registered self.assertEqual(channel.code, 400) @@ -78,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = make_request(self.reactor, site, "PUT", "presence/a/status") - render(request, resource, self.reactor) + _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 401, because the stub servlet still checks authentication self.assertEqual(channel.code, 401) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 1292145890..ea3be95cf1 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -20,7 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from synapse.config.server import parse_listener_def -from tests.server import make_request, render +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -67,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = make_request( + _, channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 401) @@ -116,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = make_request( + _, channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 401) diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index e835512a41..05e9c449be 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -17,7 +17,7 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import respond_with_json -from tests.server import FakeSite, make_request, render +from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase @@ -47,7 +47,6 @@ class AdditionalResourceTests(HomeserverTestCase): resource = AdditionalResource(self.hs, handler) request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - render(request, resource, self.reactor) self.assertEqual(request.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) @@ -57,7 +56,6 @@ class AdditionalResourceTests(HomeserverTestCase): resource = AdditionalResource(self.hs, handler) request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - render(request, resource, self.reactor) self.assertEqual(request.code, 200) 
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index bc56b13dcd..516db4c30a 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -36,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport, render +from tests.server import FakeTransport try: import hiredis @@ -347,9 +347,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config["worker_replication_http_port"] = "8765" return config - def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._hs_to_site[worker_hs].resource, self.reactor) - def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 90172bd377..96801db473 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -55,7 +55,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -69,7 +68,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. @@ -89,7 +87,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_1, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -104,7 +101,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_2, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. 
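[Editor's note: the effect on multi-worker tests, sketched with names from the surrounding diffs (`worker_hs` and `_hs_to_site` come from BaseMultiWorkerStreamTestCase). Because `make_request` now renders against whatever Site it is given, the separate `render_on_worker()` step removed above becomes unnecessary.]

```python
# Sketch only: target a specific worker's Site directly.
site = self._hs_to_site[worker_hs]
request, channel = make_request(
    self.reactor, site, "GET", "/sync", access_token=access_token
)
# No render_on_worker(worker_hs, request) follow-up: the response is
# complete by the time make_request returns.
self.assertEqual(channel.code, 200)
```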
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 2820dd622f..77fc3856d5 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -183,7 +183,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): request, channel = make_request( self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token ) - self.render_on_worker(sync_hs, request) next_batch = channel.json_body["next_batch"] # We now gut wrench into the events stream MultiWriterIdGenerator on @@ -214,7 +213,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): "/sync?since={}".format(next_batch), access_token=access_token, ) - self.render_on_worker(sync_hs, request) # We should only see the new event and nothing else self.assertIn(room_id1, channel.json_body["rooms"]["join"]) @@ -245,7 +243,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): "/sync?since={}".format(vector_clock_token), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertNotIn(room_id1, channel.json_body["rooms"]["join"]) self.assertIn(room_id2, channel.json_body["rooms"]["join"]) @@ -271,7 +268,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): "/sync?since={}".format(next_batch), access_token=access_token, ) - self.render_on_worker(sync_hs, request) prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][ "prev_batch" @@ -292,7 +288,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertListEqual([], channel.json_body["chunk"]) # Paginating back on the second room should produce the first event @@ -306,7 +301,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertEqual( channel.json_body["chunk"][0]["event_id"], first_event_in_room2 @@ -322,7 +316,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertListEqual([], channel.json_body["chunk"]) request, channel = make_request( @@ -334,7 +327,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertEqual( channel.json_body["chunk"][0]["event_id"], first_event_in_room2 diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 2931859f25..e2e6a5e16d 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.consent import consent_resource from tests import unittest -from tests.server import FakeSite, make_request, render +from tests.server import FakeSite, make_request class ConsentResourceTestCase(unittest.HomeserverTestCase): @@ -64,7 +64,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): request, channel = make_request( self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) def test_accept_consent(self): @@ -91,7 +90,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, 
shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented @@ -107,7 +105,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Fetch the consent page, to get the consent version -- it should have @@ -120,7 +117,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and check that it's the version we diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 040a92d6f0..b58768675b 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -27,7 +27,7 @@ from twisted.web.server import Site from synapse.api.constants import Membership -from tests.server import FakeSite, make_request, render +from tests.server import FakeSite, make_request @attr.s @@ -52,14 +52,13 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, "POST", path, json.dumps(content).encode("utf8"), ) - render(request, self.site.resource, self.hs.get_reactor()) assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id @@ -129,7 +128,7 @@ class RestHelper: data = {"membership": membership} data.update(extra_data) - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, "PUT", @@ -137,8 +136,6 @@ class RestHelper: json.dumps(data).encode("utf8"), ) - render(request, self.site.resource, self.hs.get_reactor()) - assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" % (expect_code, int(channel.result["code"]), channel.result["body"]) @@ -166,14 +163,13 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, "PUT", path, json.dumps(content).encode("utf8"), ) - render(request, self.site.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -223,12 +219,10 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, method, path, content ) - render(request, self.site.resource, self.hs.get_reactor()) - assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" % (expect_code, int(channel.result["code"]), channel.result["body"]) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e..f74d611943 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -320,10 +320,8 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing._reset() # Now it SHOULD fail as it never completes! 
- request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.assertRaises(TimedOutException, self.render, request) + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) class UnreadMessagesTestCase(unittest.HomeserverTestCase): diff --git a/tests/server.py b/tests/server.py index de7cb1d8b3..a51ad0c14e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -19,7 +19,6 @@ from twisted.internet.interfaces import ( ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock -from twisted.web.http import unquote from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Site @@ -267,10 +266,6 @@ def make_request( return req, channel -def render(request, resource, clock): - pass - - @implementer(IReactorPluggableNameResolver) class ThreadedMemoryReactorClock(MemoryReactorClock): """ diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 583addb5b5..6bdde1a2ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): headers1 = {b"User-Agent": b"Mozzila pizza"} headers1.update(headers) - request, channel = make_request( + make_request( self.reactor, self.site, "GET", @@ -421,7 +421,6 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): custom_headers=headers1.items(), **make_request_args, ) - self.render(request) # Advance so the save loop occurs self.reactor.advance(100) diff --git a/tests/test_server.py b/tests/test_server.py index 300d13ac95..c387a85f2e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -29,7 +29,6 @@ from tests.server import ( FakeSite, ThreadedMemoryReactorClock, make_request, - render, setup_test_homeserver, ) @@ -65,7 +64,6 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request( self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) - render(request, res, self.reactor) self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) @@ -84,10 +82,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -111,10 +106,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -132,10 +124,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -157,10 +146,9 @@ class 
JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( + _, channel = make_request( self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar" ) - render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -182,10 +170,7 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. - request, channel = make_request( - self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -216,16 +201,8 @@ class OptionsResourceTests(unittest.TestCase): "1.0", ) - request, channel = make_request( - self.reactor, site, method, path, shorthand=False - ) - request.prepath = [] # This doesn't get set properly by make_request. - - request.site = site - resource = site.getResourceFor(request) - - # Finally, render the resource and return the channel. - render(request, resource, self.reactor) + # render the request and return the channel + _, channel = make_request(self.reactor, site, method, path, shorthand=False) return channel def test_unknown_options_request(self): @@ -298,8 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -317,8 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -339,8 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -359,8 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) diff --git a/tests/unittest.py b/tests/unittest.py index 9c7eca3b6e..8a49bb5262 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -48,13 +48,7 @@ from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import ( - FakeChannel, - get_clock, - make_request, - render, - setup_test_homeserver, -) +from tests.server import 
FakeChannel, get_clock, make_request, setup_test_homeserver from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -454,7 +448,7 @@ class HomeserverTestCase(TestCase): Args: request (synapse.http.site.SynapseRequest): The request to render. """ - render(request, self.resource, self.reactor) + pass def setup_test_homeserver(self, *args, **kwargs): """ -- cgit 1.5.1
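[Editor's closing note: a sketch of the timeout behaviour this series converges on, mirroring the test_sync.py change above; `sync_url` and `next_batch` are the names used in that test. Since `make_request` now waits by default, a request that never completes surfaces as `FakeChannel.await_result()` exhausting its iteration budget rather than as a half-rendered request object.]

```python
from tests.server import TimedOutException  # raised by FakeChannel.await_result()

# Inside a HomeserverTestCase: a /sync long-poll that never finishes
# makes make_request raise instead of returning.
with self.assertRaises(TimedOutException):
    self.make_request("GET", sync_url % (access_token, next_batch))
```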