diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c17e3b2399..f7bea79b0d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools:
```
# Install the dependencies
-pip install -e ".[lint]"
+pip install -e ".[lint,mypy]"
# Run the linter script
./scripts-dev/lint.sh
@@ -63,7 +63,7 @@ run-time:
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
-You can also provided the `-d` option, which will lint the files that have been
+You can also provide the `-d` option, which will lint the files that have been
changed since the last git commit. This will often be significantly faster than
linting the whole codebase.
diff --git a/INSTALL.md b/INSTALL.md
index 22f7b7c029..c6fcb3bd7f 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -57,7 +57,7 @@ light workloads.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
-- Python 3.5.2 or later, up to Python 3.8.
+- Python 3.5.2 or later, up to Python 3.9.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in Python but some of the libraries it uses are written in
diff --git a/README.rst b/README.rst
index d609b4b62e..59d5a4389b 100644
--- a/README.rst
+++ b/README.rst
@@ -256,9 +256,9 @@ directory of your choice::
Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
- virtualenv -p python3 env
- source env/bin/activate
- python -m pip install --no-use-pep517 -e ".[all]"
+ python3 -m venv ./env
+ source ./env/bin/activate
+ pip install -e ".[all,test]"
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
@@ -270,9 +270,9 @@ check that everything is installed as it should be::
This should end with a 'PASSED' result::
- Ran 143 tests in 0.601s
+ Ran 1266 tests in 643.930s
- PASSED (successes=143)
+ PASSED (skips=15, successes=1251)
Running the Integration Tests
=============================
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 5a68312217..960c2aeb2b 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,22 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.23.0
+====================
+
+Structured logging configuration breaking changes
+-------------------------------------------------
+
+This release deprecates use of the ``structured: true`` logging configuration for
+structured logging. If your logging configuration contains ``structured: true``
+then it should be modified based on the `structured logging documentation
+<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_.
+
+The ``structured`` and ``drains`` logging options are now deprecated and should
+be replaced by a standard logging configuration using ``handlers`` and
+``formatters``.
+
+A future release of Synapse will make using ``structured: true`` an error.
+
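+As a quick way to check whether an existing log config still uses the
+deprecated option, you can load it and look for the ``structured`` flag. This
+is an illustrative sketch only; the config path below is a placeholder:
+
+.. code:: python
+
+    import yaml
+
+    # Path to your Synapse log config (placeholder).
+    with open("/etc/matrix-synapse/log.yaml") as f:
+        log_config = yaml.safe_load(f)
+
+    if log_config and log_config.get("structured"):
+        print("This log config uses the deprecated 'structured: true' option;")
+        print("see the structured logging documentation for how to migrate.")
+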
Upgrading to v1.22.0
====================
diff --git a/changelog.d/8559.misc b/changelog.d/8559.misc
new file mode 100644
index 0000000000..d7bd00964e
--- /dev/null
+++ b/changelog.d/8559.misc
@@ -0,0 +1 @@
+Optimise `/createRoom` with multiple invited users.
diff --git a/changelog.d/8595.misc b/changelog.d/8595.misc
new file mode 100644
index 0000000000..24fab65cda
--- /dev/null
+++ b/changelog.d/8595.misc
@@ -0,0 +1 @@
+Implement and use an @lru_cache decorator.
diff --git a/changelog.d/8607.feature b/changelog.d/8607.feature
new file mode 100644
index 0000000000..fef1eccb92
--- /dev/null
+++ b/changelog.d/8607.feature
@@ -0,0 +1 @@
+Support generating structured logs via the standard logging configuration.
diff --git a/changelog.d/8610.feature b/changelog.d/8610.feature
new file mode 100644
index 0000000000..ed8d926964
--- /dev/null
+++ b/changelog.d/8610.feature
@@ -0,0 +1 @@
+Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel.
\ No newline at end of file
diff --git a/changelog.d/8616.misc b/changelog.d/8616.misc
new file mode 100644
index 0000000000..385b14063e
--- /dev/null
+++ b/changelog.d/8616.misc
@@ -0,0 +1 @@
+Change schema to support access tokens belonging to one user but granting access to another.
diff --git a/changelog.d/8633.misc b/changelog.d/8633.misc
new file mode 100644
index 0000000000..8e1d006b36
--- /dev/null
+++ b/changelog.d/8633.misc
@@ -0,0 +1 @@
+Run `mypy` as part of the lint.sh script.
diff --git a/changelog.d/8655.misc b/changelog.d/8655.misc
new file mode 100644
index 0000000000..b588bdd3e2
--- /dev/null
+++ b/changelog.d/8655.misc
@@ -0,0 +1 @@
+Add more type hints to the application services code.
diff --git a/changelog.d/8664.misc b/changelog.d/8664.misc
new file mode 100644
index 0000000000..278cf53adc
--- /dev/null
+++ b/changelog.d/8664.misc
@@ -0,0 +1 @@
+Tell Black to format code for Python 3.5.
diff --git a/changelog.d/8665.doc b/changelog.d/8665.doc
new file mode 100644
index 0000000000..3b75307dc5
--- /dev/null
+++ b/changelog.d/8665.doc
@@ -0,0 +1 @@
+Note support for Python 3.9.
diff --git a/changelog.d/8666.doc b/changelog.d/8666.doc
new file mode 100644
index 0000000000..dee86b4a26
--- /dev/null
+++ b/changelog.d/8666.doc
@@ -0,0 +1 @@
+Minor updates to docs on running tests.
diff --git a/changelog.d/8667.doc b/changelog.d/8667.doc
new file mode 100644
index 0000000000..422d697da6
--- /dev/null
+++ b/changelog.d/8667.doc
@@ -0,0 +1 @@
+Interlink prometheus/grafana documentation.
diff --git a/changelog.d/8669.misc b/changelog.d/8669.misc
new file mode 100644
index 0000000000..5228105cd3
--- /dev/null
+++ b/changelog.d/8669.misc
@@ -0,0 +1 @@
+Don't pull event from DB when handling replication traffic.
diff --git a/changelog.d/8671.misc b/changelog.d/8671.misc
new file mode 100644
index 0000000000..bef8dc425a
--- /dev/null
+++ b/changelog.d/8671.misc
@@ -0,0 +1 @@
+Abstract some invite-related code in preparation for landing knocking.
\ No newline at end of file
diff --git a/changelog.d/8676.bugfix b/changelog.d/8676.bugfix
new file mode 100644
index 0000000000..df16c72761
--- /dev/null
+++ b/changelog.d/8676.bugfix
@@ -0,0 +1 @@
+Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0.
diff --git a/changelog.d/8678.bugfix b/changelog.d/8678.bugfix
new file mode 100644
index 0000000000..0508d8f109
--- /dev/null
+++ b/changelog.d/8678.bugfix
@@ -0,0 +1 @@
+Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules.
diff --git a/changelog.d/8679.misc b/changelog.d/8679.misc
new file mode 100644
index 0000000000..662eced4cf
--- /dev/null
+++ b/changelog.d/8679.misc
@@ -0,0 +1 @@
+Clarify representation of events in logfiles.
diff --git a/changelog.d/8680.misc b/changelog.d/8680.misc
new file mode 100644
index 0000000000..2ca2975464
--- /dev/null
+++ b/changelog.d/8680.misc
@@ -0,0 +1 @@
+Don't require `hiredis` package to be installed to run unit tests.
diff --git a/changelog.d/8682.bugfix b/changelog.d/8682.bugfix
new file mode 100644
index 0000000000..e61276aa05
--- /dev/null
+++ b/changelog.d/8682.bugfix
@@ -0,0 +1 @@
+Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories.
diff --git a/changelog.d/8684.misc b/changelog.d/8684.misc
new file mode 100644
index 0000000000..1d23d42926
--- /dev/null
+++ b/changelog.d/8684.misc
@@ -0,0 +1 @@
+Fix typing info on cache call signature to accept `on_invalidate`.
diff --git a/changelog.d/8685.feature b/changelog.d/8685.feature
new file mode 100644
index 0000000000..fef1eccb92
--- /dev/null
+++ b/changelog.d/8685.feature
@@ -0,0 +1 @@
+Support generating structured logs via the standard logging configuration.
diff --git a/changelog.d/8688.misc b/changelog.d/8688.misc
new file mode 100644
index 0000000000..bef8dc425a
--- /dev/null
+++ b/changelog.d/8688.misc
@@ -0,0 +1 @@
+Abstract some invite-related code in preparation for landing knocking.
\ No newline at end of file
diff --git a/changelog.d/8689.feature b/changelog.d/8689.feature
new file mode 100644
index 0000000000..ed8d926964
--- /dev/null
+++ b/changelog.d/8689.feature
@@ -0,0 +1 @@
+Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel.
\ No newline at end of file
diff --git a/changelog.d/8690.misc b/changelog.d/8690.misc
new file mode 100644
index 0000000000..0f38ba1f5d
--- /dev/null
+++ b/changelog.d/8690.misc
@@ -0,0 +1 @@
+Fail tests if they do not await coroutines.
diff --git a/contrib/grafana/README.md b/contrib/grafana/README.md
index ca780d412e..4608793394 100644
--- a/contrib/grafana/README.md
+++ b/contrib/grafana/README.md
@@ -3,4 +3,4 @@
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
-3. Set up additional recording rules
+3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 636fc284e4..d4051d0257 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -611,3 +611,82 @@ The following parameters should be set in the URL:
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
- ``device_id`` - The device to delete.
+
+List all pushers
+================
+Gets information about all pushers for a specific ``user_id``.
+
+The API is::
+
+ GET /_synapse/admin/v1/users/<user_id>/pushers
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+ {
+ "pushers": [
+ {
+ "app_display_name":"HTTP Push Notifications",
+ "app_id":"m.http",
+ "data": {
+ "url":"example.com"
+ },
+ "device_display_name":"pushy push",
+ "kind":"http",
+ "lang":"None",
+ "profile_tag":"",
+ "pushkey":"a@example.com"
+ }
+ ],
+ "total": 1
+ }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``pushers`` - An array containing the current pushers for the user
+
+ - ``app_display_name`` - string - A string that will allow the user to identify
+ what application owns this pusher.
+
+ - ``app_id`` - string - This is a reverse-DNS style identifier for the application.
+ Max length, 64 chars.
+
+ - ``data`` - A dictionary of information for the pusher implementation itself.
+
+ - ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send
+ notifications to.
+
+ - ``format`` - string - The format to use when sending notifications to the
+ Push Gateway.
+
+ - ``device_display_name`` - string - A string that will allow the user to identify
+ what device owns this pusher.
+
+ - ``profile_tag`` - string - This string determines which set of device specific rules
+ this pusher executes.
+
+ - ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes.
+ - ``lang`` - string - The preferred language for receiving notifications
+ (e.g. 'en' or 'en-US')
+
+ - ``pushkey`` - string - This is a unique identifier for this pusher.
+ Max length, 512 bytes.
+
+- ``total`` - integer - Number of pushers.
+
+See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
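+
+For example, a minimal sketch of calling this API from Python with the
+``requests`` library (the homeserver URL, admin access token, and user ID
+below are placeholders):
+
+.. code:: python
+
+    import requests
+
+    homeserver = "https://homeserver.example.com"  # placeholder
+    admin_token = "<admin access token>"           # placeholder
+    user_id = "@user:server.com"
+
+    response = requests.get(
+        "{}/_synapse/admin/v1/users/{}/pushers".format(homeserver, user_id),
+        headers={"Authorization": "Bearer {}".format(admin_token)},
+    )
+    response.raise_for_status()
+
+    for pusher in response.json()["pushers"]:
+        print(pusher["app_id"], pusher["pushkey"])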
diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md
index b386ec91c1..fb71af4911 100644
--- a/docs/metrics-howto.md
+++ b/docs/metrics-howto.md
@@ -60,6 +60,8 @@
1. Restart Prometheus.
+1. Consider using the [Grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/).
+
## Monitoring workers
To monitor a Synapse installation using
diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml
index e26657f9fe..ff3c747180 100644
--- a/docs/sample_log_config.yaml
+++ b/docs/sample_log_config.yaml
@@ -3,7 +3,11 @@
# This is a YAML file containing a standard Python logging configuration
# dictionary. See [1] for details on the valid settings.
#
+# Synapse also supports structured logging for machine readable logs which can
+# be ingested by ELK stacks. See [2] for details.
+#
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
+# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
version: 1
diff --git a/docs/structured_logging.md b/docs/structured_logging.md
index decec9b8fa..b1281667e0 100644
--- a/docs/structured_logging.md
+++ b/docs/structured_logging.md
@@ -1,83 +1,161 @@
# Structured Logging
-A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack".
+A structured logging system can be useful when your logs are destined for a
+machine to parse and process. By maintaining its machine-readable characteristics,
+it enables more efficient searching and aggregations when consumed by software
+such as the "ELK stack".
-Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to).
+Synapse's structured logging system is configured via the file that Synapse's
+`log_config` config option points to. The file should include a formatter which
+uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a
+handler which uses the above formatter.
+
+There is also a `synapse.logging.JsonFormatter` option which does not include
+a timestamp in the resulting JSON. This is useful if the log ingester adds its
+own timestamp.
A structured logging configuration looks similar to the following:
```yaml
-structured: true
+version: 1
+
+formatters:
+ structured:
+ class: synapse.logging.TerseJsonFormatter
+
+handlers:
+ file:
+ class: logging.handlers.TimedRotatingFileHandler
+ formatter: structured
+ filename: /path/to/my/logs/homeserver.log
+ when: midnight
+ backupCount: 3 # Does not include the current log file.
+ encoding: utf8
loggers:
synapse:
level: INFO
+        handlers: [file]
synapse.storage.SQL:
level: WARNING
-
-drains:
- console:
- type: console
- location: stdout
- file:
- type: file_json
- location: homeserver.log
```
-The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON).
-
-## Drain Types
+The above logging config will set Synapse's default logging level to 'INFO',
+with the SQL layer at 'WARNING', and will log to a file, stored as JSON.
-Drain types can be specified by the `type` key.
+It is also possible to configure Synapse to log to a remote endpoint by using
+the `synapse.logging.RemoteHandler` class included with Synapse. It takes the
+following arguments:
-### `console`
+- `host`: Hostname or IP address of the log aggregator.
+- `port`: Numerical port to contact on the host.
+- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow.
-Outputs human-readable logs to the console.
+A remote structured logging configuration looks similar to the following:
-Arguments:
+```yaml
+version: 1
-- `location`: Either `stdout` or `stderr`.
+formatters:
+ structured:
+ class: synapse.logging.TerseJsonFormatter
-### `console_json`
+handlers:
+ remote:
+ class: synapse.logging.RemoteHandler
+ formatter: structured
+ host: 10.1.2.3
+ port: 9999
-Outputs machine-readable JSON logs to the console.
+loggers:
+ synapse:
+ level: INFO
+ handlers: [remote]
+ synapse.storage.SQL:
+ level: WARNING
+```
-Arguments:
+The above logging config will set Synapse's default logging level to 'INFO',
+with the SQL layer at 'WARNING', and will send JSON-formatted messages to a
+remote endpoint at 10.1.2.3:9999.
-- `location`: Either `stdout` or `stderr`.
+## Upgrading from legacy structured logging configuration
-### `console_json_terse`
+Versions of Synapse prior to v1.23.0 included a custom structured logging
+configuration which is now deprecated. It used a `structured: true` flag and
+configured `drains` instead of `handlers` and `formatters`.
-Outputs machine-readable JSON logs to the console, separated by newlines. This
-format is not designed to be read and re-formatted into human-readable text, but
-is optimal for a logging aggregation system.
+Synapse currently converts the old configuration to the new configuration
+automatically, but this conversion will be removed in a future version of
+Synapse. The following reference can be used to update your configuration.
+Based on the drain `type`, we can pick a new handler:
-Arguments:
+1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
+ with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
+ or `ext://sys.stderr` should be used.
+2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with
+   a `filename` of the file path should be used.
+3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler`
+ with the host and port should be used.
-- `location`: Either `stdout` or `stderr`.
+Then, based on the drain `type`, we can pick a new formatter:
-### `file`
+1. For a type of `console` or `file` no formatter is necessary.
+2. For a type of `console_json` or `file_json`: a formatter of
+ `synapse.logging.JsonFormatter` should be used.
+3. For a type of `console_json_terse` or `network_json_terse`: a formatter of
+ `synapse.logging.TerseJsonFormatter` should be used.
-Outputs human-readable logs to a file.
+Each new handler and formatter should be added to the logging configuration
+and then assigned to either a logger or the root logger.
-Arguments:
+An example legacy configuration:
-- `location`: An absolute path to the file to log to.
+```yaml
+structured: true
-### `file_json`
+loggers:
+ synapse:
+ level: INFO
+ synapse.storage.SQL:
+ level: WARNING
-Outputs machine-readable logs to a file.
+drains:
+ console:
+ type: console
+ location: stdout
+ file:
+ type: file_json
+ location: homeserver.log
+```
-Arguments:
+Would be converted into a new configuration:
-- `location`: An absolute path to the file to log to.
+```yaml
+version: 1
-### `network_json_terse`
+formatters:
+ json:
+ class: synapse.logging.JsonFormatter
-Delivers machine-readable JSON logs to a log aggregator over TCP. This is
-compatible with LogStash's TCP input with the codec set to `json_lines`.
+handlers:
+ console:
+ class: logging.StreamHandler
+    stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: json
+ filename: homeserver.log
-Arguments:
+loggers:
+ synapse:
+ level: INFO
+ handlers: [console, file]
+ synapse.storage.SQL:
+ level: WARNING
+```
-- `host`: Hostname or IP address of the log aggregator.
-- `port`: Numerical port to contact on the host.
\ No newline at end of file
+The new logging configuration is a bit more verbose, but significantly more
+flexible. It allows for configurations that were not previously possible, such as
+sending plain logs over the network, or using different handlers for different
+modules.
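+
+As a minimal illustration of how such a configuration is applied (assuming
+Synapse is installed so that `synapse.logging.JsonFormatter` can be resolved),
+the equivalent of the converted YAML above can be passed directly to the
+standard library's `logging.config.dictConfig`; this mirrors how Synapse loads
+the file that `log_config` points to:
+
+```python
+import logging
+import logging.config
+
+# Dict equivalent of the converted YAML configuration above.
+LOG_CONFIG = {
+    "version": 1,
+    "formatters": {"json": {"class": "synapse.logging.JsonFormatter"}},
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "stream": "ext://sys.stdout",
+        },
+        "file": {
+            "class": "logging.FileHandler",
+            "formatter": "json",
+            "filename": "homeserver.log",
+        },
+    },
+    "loggers": {"synapse": {"level": "INFO", "handlers": ["console", "file"]}},
+}
+
+logging.config.dictConfig(LOG_CONFIG)
+logging.getLogger("synapse").info("structured logging configured")
+```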
diff --git a/mypy.ini b/mypy.ini
index 1fbd8decf8..1ece2ba082 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -57,6 +57,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
+ synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
@@ -82,6 +83,9 @@ ignore_missing_imports = True
[mypy-zope]
ignore_missing_imports = True
+[mypy-bcrypt]
+ignore_missing_imports = True
+
[mypy-constantly]
ignore_missing_imports = True
diff --git a/pyproject.toml b/pyproject.toml
index db4a2e41e4..cd880d4e39 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@
showcontent = true
[tool.black]
-target-version = ['py34']
+target-version = ['py35']
exclude = '''
(
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index f2b65a2105..f328ab57d5 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -80,7 +80,7 @@ else
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
- files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
+ files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
fi
fi
@@ -94,3 +94,4 @@ isort "${files[@]}"
python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
flake8 "${files[@]}"
+mypy
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index a5b88731f1..5882f3a0b0 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -19,9 +19,10 @@ can crop up, e.g the cache descriptors.
from typing import Callable, Optional
+from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
-from mypy.types import CallableType
+from mypy.types import CallableType, NoneType
class SynapsePlugin(Plugin):
@@ -40,8 +41,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
It already has *almost* the correct signature, except:
- 1. the `self` argument needs to be marked as "bound"; and
- 2. any `cache_context` argument should be removed.
+ 1. the `self` argument needs to be marked as "bound";
+ 2. any `cache_context` argument should be removed;
+    3. an optional keyword argument `on_invalidate` should be added.
"""
# First we mark this as a bound function signature.
@@ -58,19 +60,33 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
context_arg_index = idx
break
+ arg_types = list(signature.arg_types)
+ arg_names = list(signature.arg_names)
+ arg_kinds = list(signature.arg_kinds)
+
if context_arg_index:
- arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)
-
- arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)
-
- arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)
- signature = signature.copy_modified(
- arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
- )
+ # Third, we add an optional "on_invalidate" argument.
+ #
+ # This is a callable which accepts no input and returns nothing.
+ calltyp = CallableType(
+ arg_types=[],
+ arg_kinds=[],
+ arg_names=[],
+ ret_type=NoneType(),
+ fallback=ctx.api.named_generic_type("builtins.function", []),
+ )
+
+ arg_types.append(calltyp)
+ arg_names.append("on_invalidate")
+ arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
+
+ signature = signature.copy_modified(
+ arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
+ )
return signature
diff --git a/setup.py b/setup.py
index 2f4a3170d2..9730afb41b 100755
--- a/setup.py
+++ b/setup.py
@@ -131,6 +131,7 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
],
scripts=["synctl"] + glob.glob("scripts/*"),
cmdclass={"test": TestCommand},
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 526cb58c5f..bfcaf68b2a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -190,10 +191,6 @@ class Auth:
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
- request.authenticated_entity = user_id
- opentracing.set_tag("authenticated_entity", user_id)
- opentracing.set_tag("appservice_id", app_service.id)
-
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@@ -203,31 +200,38 @@ class Auth:
device_id="dummy-device", # stubbed
)
- return synapse.types.create_requester(user_id, app_service=app_service)
+ requester = synapse.types.create_requester(
+ user_id, app_service=app_service
+ )
+
+ request.requester = user_id
+ opentracing.set_tag("authenticated_entity", user_id)
+ opentracing.set_tag("user_id", user_id)
+ opentracing.set_tag("appservice_id", app_service.id)
+
+ return requester
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
- user = user_info["user"]
- token_id = user_info["token_id"]
- is_guest = user_info["is_guest"]
- shadow_banned = user_info["shadow_banned"]
+ token_id = user_info.token_id
+ is_guest = user_info.is_guest
+ shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
- user_id = user.to_string()
- if await self.store.is_account_expired(user_id, self.clock.time_msec()):
+ if await self.store.is_account_expired(
+ user_info.user_id, self.clock.time_msec()
+ ):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
- # device_id may not be present if get_user_by_access_token has been
- # stubbed out.
- device_id = user_info.get("device_id")
+ device_id = user_info.device_id
- if user and access_token and ip_addr:
+ if access_token and ip_addr:
await self.store.insert_client_ip(
- user_id=user.to_string(),
+ user_id=user_info.token_owner,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@@ -241,19 +245,23 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
- request.authenticated_entity = user.to_string()
- opentracing.set_tag("authenticated_entity", user.to_string())
- if device_id:
- opentracing.set_tag("device_id", device_id)
-
- return synapse.types.create_requester(
- user,
+ requester = synapse.types.create_requester(
+ user_info.user_id,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
+ authenticated_entity=user_info.token_owner,
)
+
+ request.requester = requester
+ opentracing.set_tag("authenticated_entity", user_info.token_owner)
+ opentracing.set_tag("user_id", user_info.user_id)
+ if device_id:
+ opentracing.set_tag("device_id", device_id)
+
+ return requester
except KeyError:
raise MissingClientTokenError()
@@ -284,7 +292,7 @@ class Auth:
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
- ) -> dict:
+ ) -> TokenLookupResult:
""" Validate access token and get user_id from it
Args:
@@ -293,13 +301,7 @@ class Auth:
allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
- Returns:
- dict that includes:
- `user` (UserID)
- `is_guest` (bool)
- `shadow_banned` (bool)
- `token_id` (int|None): access token id. May be None if guest
- `device_id` (str|None): device corresponding to access token
+
Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
@@ -309,9 +311,9 @@ class Auth:
if rights == "access":
# first look in the database
- r = await self._look_up_user_by_access_token(token)
+ r = await self.store.get_user_by_access_token(token)
if r:
- valid_until_ms = r["valid_until_ms"]
+ valid_until_ms = r.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@@ -328,7 +330,6 @@ class Auth:
# otherwise it needs to be a valid macaroon
try:
user_id, guest = self._parse_and_validate_macaroon(token, rights)
- user = UserID.from_string(user_id)
if rights == "access":
if not guest:
@@ -354,23 +355,17 @@ class Auth:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
- ret = {
- "user": user,
- "is_guest": True,
- "shadow_banned": False,
- "token_id": None,
+
+ ret = TokenLookupResult(
+ user_id=user_id,
+ is_guest=True,
# all guests get the same device id
- "device_id": GUEST_DEVICE_ID,
- }
+ device_id=GUEST_DEVICE_ID,
+ )
elif rights == "delete_pusher":
# We don't store these tokens in the database
- ret = {
- "user": user,
- "is_guest": False,
- "shadow_banned": False,
- "token_id": None,
- "device_id": None,
- }
+
+ ret = TokenLookupResult(user_id=user_id, is_guest=False)
else:
raise RuntimeError("Unknown rights setting %s", rights)
return ret
@@ -479,31 +474,15 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
- async def _look_up_user_by_access_token(self, token):
- ret = await self.store.get_user_by_access_token(token)
- if not ret:
- return None
-
- # we use ret.get() below because *lots* of unit tests stub out
- # get_user_by_access_token in a way where it only returns a couple of
- # the fields.
- user_info = {
- "user": UserID.from_string(ret.get("name")),
- "token_id": ret.get("token_id", None),
- "is_guest": False,
- "shadow_banned": ret.get("shadow_banned"),
- "device_id": ret.get("device_id"),
- "valid_until_ms": ret.get("valid_until_ms"),
- }
- return user_info
-
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
- request.authenticated_entity = service.sender
+ request.requester = synapse.types.create_requester(
+ service.sender, app_service=service
+ )
return service
async def is_server_admin(self, user: UserID) -> bool:
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 3862d9c08f..3944780a42 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationServiceApi
@@ -52,11 +52,11 @@ class ApplicationService:
self,
token,
hostname,
+ id,
+ sender,
url=None,
namespaces=None,
hs_token=None,
- sender=None,
- id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
@@ -164,9 +164,9 @@ class ApplicationService:
does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
- @cached(num_args=1)
+ @cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(
- self, room_id: str, store: "DataStore"
+ self, room_id: str, store: "DataStore", cache_context: _CacheContext,
) -> bool:
"""Check if this service is interested a room based upon it's membership
@@ -177,7 +177,9 @@ class ApplicationService:
Returns:
True if this service would like to know about this room.
"""
- member_list = await store.get_users_in_room(room_id)
+ member_list = await store.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
# check joined member events
for user_id in member_list:
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 6b7be28aee..d4e887a3e0 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -23,7 +23,6 @@ from string import Template
import yaml
from twisted.logger import (
- ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@@ -32,11 +31,9 @@ from twisted.logger import (
import synapse
from synapse.app import _base as appbase
-from synapse.logging._structured import (
- reload_structured_logging,
- setup_structured_logging,
-)
+from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContextFilter
+from synapse.logging.filter import MetadataFilter
from synapse.util.versionstring import get_version_string
from ._base import Config, ConfigError
@@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template(
# This is a YAML file containing a standard Python logging configuration
# dictionary. See [1] for details on the valid settings.
#
+# Synapse also supports structured logging for machine readable logs which can
+# be ingested by ELK stacks. See [2] for details.
+#
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
+# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
version: 1
@@ -176,11 +177,11 @@ class LoggingConfig(Config):
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
-def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
+def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
"""
- Set up Python stdlib logging.
+ Set up Python standard library logging.
"""
- if log_config is None:
+ if log_config_path is None:
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
@@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
handler.setFormatter(formatter)
logger.addHandler(handler)
else:
- logging.config.dictConfig(log_config)
+ # Load the logging configuration.
+ _load_logging_config(log_config_path)
# We add a log record factory that runs all messages through the
# LoggingContextFilter so that we get the context *at the time we log*
@@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
- log_filter = LoggingContextFilter(request="")
+ log_context_filter = LoggingContextFilter(request="")
+ log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
record = old_factory(*args, **kwargs)
- log_filter.filter(record)
+ log_context_filter.filter(record)
+ log_metadata_filter.filter(record)
return record
logging.setLogRecordFactory(factory)
@@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")
- return observer
-
-def _reload_stdlib_logging(*args, log_config=None):
- logger = logging.getLogger("")
+def _load_logging_config(log_config_path: str) -> None:
+ """
+ Configure logging from a log config path.
+ """
+ with open(log_config_path, "rb") as f:
+ log_config = yaml.safe_load(f.read())
if not log_config:
- logger.warning("Reloaded a blank config?")
+ logging.warning("Loaded a blank logging config?")
+
+ # If the old structured logging configuration is being used, convert it to
+ # the new style configuration.
+ if "structured" in log_config and log_config.get("structured"):
+ log_config = setup_structured_logging(log_config)
logging.config.dictConfig(log_config)
+def _reload_logging_config(log_config_path):
+ """
+ Reload the log configuration from the file and apply it.
+ """
+ # If no log config path was given, it cannot be reloaded.
+ if log_config_path is None:
+ return
+
+ _load_logging_config(log_config_path)
+ logging.info("Reloaded log config from %s due to SIGHUP", log_config_path)
+
+
def setup_logging(
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
-) -> ILogObserver:
+) -> None:
"""
Set up the logging subsystem.
@@ -282,41 +305,18 @@ def setup_logging(
logBeginner: The Twisted logBeginner to use.
- Returns:
- The "root" Twisted Logger observer, suitable for sending logs to from a
- Logger instance.
"""
- log_config = config.worker_log_config if use_worker_options else config.log_config
-
- def read_config(*args, callback=None):
- if log_config is None:
- return None
-
- with open(log_config, "rb") as f:
- log_config_body = yaml.safe_load(f.read())
-
- if callback:
- callback(log_config=log_config_body)
- logging.info("Reloaded log config from %s due to SIGHUP", log_config)
-
- return log_config_body
+ log_config_path = (
+ config.worker_log_config if use_worker_options else config.log_config
+ )
- log_config_body = read_config()
+ # Perform one-time logging configuration.
+ _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
+ # Add a SIGHUP handler to reload the logging configuration, if one is available.
+ appbase.register_sighup(_reload_logging_config, log_config_path)
- if log_config_body and log_config_body.get("structured") is True:
- logger = setup_structured_logging(
- hs, config, log_config_body, logBeginner=logBeginner
- )
- appbase.register_sighup(read_config, callback=reload_structured_logging)
- else:
- logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
- appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
-
- # make sure that the first thing we log is a thing we can grep backwards
- # for
+ # Log immediately so we can grep backwards.
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
logging.info("Instance name: %s", hs.get_instance_name())
-
- return logger
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e203206865..8028663fa8 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -368,7 +368,7 @@ class FrozenEvent(EventBase):
return self.__repr__()
def __repr__(self):
- return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
+ return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
@@ -451,7 +451,7 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
- return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+ return "<%s event_id=%r, type=%r, state_key=%r>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3a6b95631e..a0933fae88 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -154,7 +154,7 @@ class Authenticator:
)
logger.debug("Request from %s", origin)
- request.authenticated_entity = origin
+ request.requester = origin
# If we get a valid signed request from the other side, its probably
# alive
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3ed29a2c16..9fc8444228 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
from prometheus_client import Counter
@@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
class ApplicationServicesHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
@@ -247,7 +250,9 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
- async def _handle_typing(self, service: ApplicationService, new_token: int):
+ async def _handle_typing(
+ self, service: ApplicationService, new_token: int
+ ) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
@@ -259,7 +264,7 @@ class ApplicationServicesHandler:
)
return typing
- async def _handle_receipts(self, service: ApplicationService):
+ async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
@@ -271,7 +276,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
- ):
+ ) -> List[JsonDict]:
events = [] # type: List[JsonDict]
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
@@ -301,11 +306,11 @@ class ApplicationServicesHandler:
return events
- async def query_user_exists(self, user_id):
+ async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
Args:
- user_id(str): The user to query if they exist on any AS.
+ user_id: The user to query if they exist on any AS.
Returns:
True if this user exists on at least one application service.
"""
@@ -316,11 +321,13 @@ class ApplicationServicesHandler:
return True
return False
- async def query_room_alias_exists(self, room_alias):
+ async def query_room_alias_exists(
+ self, room_alias: RoomAlias
+ ) -> Optional[RoomAliasMapping]:
"""Check if an application service knows this room alias exists.
Args:
- room_alias(RoomAlias): The room alias to query.
+ room_alias: The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
@@ -336,10 +343,13 @@ class ApplicationServicesHandler:
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = await self.store.get_association_from_room_alias(room_alias)
- return result
+ return await self.store.get_association_from_room_alias(room_alias)
+
+ return None
- async def query_3pe(self, kind, protocol, fields):
+ async def query_3pe(
+ self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+ ) -> List[JsonDict]:
services = self._get_services_for_3pn(protocol)
results = await make_deferred_yieldable(
@@ -361,7 +371,9 @@ class ApplicationServicesHandler:
return ret
- async def get_3pe_protocols(self, only_protocol=None):
+ async def get_3pe_protocols(
+ self, only_protocol: Optional[str] = None
+ ) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
@@ -379,7 +391,7 @@ class ApplicationServicesHandler:
if info is not None:
protocols[p].append(info)
- def _merge_instances(infos):
+ def _merge_instances(infos: List[JsonDict]) -> JsonDict:
if not infos:
return {}
@@ -394,19 +406,17 @@ class ApplicationServicesHandler:
return combined
- for p in protocols.keys():
- protocols[p] = _merge_instances(protocols[p])
+ return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
- return protocols
-
- async def _get_services_for_event(self, event):
+ async def _get_services_for_event(
+ self, event: EventBase
+ ) -> List[ApplicationService]:
"""Retrieve a list of application services interested in this event.
Args:
- event(Event): The event to check. Can be None if alias_list is not.
+ event: The event to check. Can be None if alias_list is not.
Returns:
- list<ApplicationService>: A list of services interested in this
- event based on the service regex.
+ A list of services interested in this event based on the service regex.
"""
services = self.store.get_app_services()
@@ -420,17 +430,15 @@ class ApplicationServicesHandler:
return interested_list
- def _get_services_for_user(self, user_id):
+ def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
services = self.store.get_app_services()
- interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
- return interested_list
+ return [s for s in services if (s.is_interested_in_user(user_id))]
- def _get_services_for_3pn(self, protocol):
+ def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
services = self.store.get_app_services()
- interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
- return interested_list
+ return [s for s in services if s.is_interested_in_protocol(protocol)]
- async def _is_unknown_user(self, user_id):
+ async def _is_unknown_user(self, user_id: str) -> bool:
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
@@ -445,9 +453,8 @@ class ApplicationServicesHandler:
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
- async def _check_user_exists(self, user_id):
+ async def _check_user_exists(self, user_id: str) -> bool:
unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
- exists = await self.query_user_exists(user_id)
- return exists
+ return await self.query_user_exists(user_id)
return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dd14ab69d7..ff103cbb92 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,10 +18,20 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
-import bcrypt # type: ignore[import]
+import bcrypt
import pymacaroons
from synapse.api.constants import LoginType
@@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
@@ -982,17 +991,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
- user_id=str(user_info["user"]),
- device_id=user_info["device_id"],
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
- if user_info["token_id"] is not None:
+ if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"],)
+ user_info.user_id, (user_info.token_id,)
)
async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9c0096bae5..31f91e0a1a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.async_helpers import Linearizer
-from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -928,7 +927,7 @@ class EventCreationHandler:
# Ensure that we can round trip before trying to persist in db
try:
- dump = frozendict_json_encoder.encode(event.content)
+ dump = json_encoder.encode(event.content)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
@@ -1100,34 +1099,13 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
-
- def is_inviter_member_event(e):
- return e.type == EventTypes.Member and e.sender == event.sender
-
- current_state_ids = await context.get_current_state_ids()
-
- # We know this event is not an outlier, so this must be
- # non-None.
- assert current_state_ids is not None
-
- state_to_include_ids = [
- e_id
- for k, e_id in current_state_ids.items()
- if k[0] in self.room_invite_state_types
- or k == (EventTypes.Member, event.sender)
- ]
-
- state_to_include = await self.store.get_events(state_to_include_ids)
-
- event.unsigned["invite_room_state"] = [
- {
- "type": e.type,
- "state_key": e.state_key,
- "content": e.content,
- "sender": e.sender,
- }
- for e in state_to_include.values()
- ]
+ event.unsigned[
+ "invite_room_state"
+ ] = await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_invite_state_types,
+ membership_user_id=event.sender,
+ )
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 49a00eed9c..8e014c9bb5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer
MYPY = False
if MYPY:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class BasePresenceHandler(abc.ABC):
"""Parts of the PresenceHandler that are shared between workers and master"""
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC):
class PresenceHandler(BasePresenceHandler):
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
@@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):
class PresenceEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
@@ -1071,12 +1071,14 @@ class PresenceEventSource:
users_interested_in = await self._get_interested_in(user, explicit_room_id)
- user_ids_changed = set()
+ user_ids_changed = set() # type: Collection[str]
changed = None
if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key)
if changed is not None and len(changed) < 500:
+ assert isinstance(user_ids_changed, set)
+
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
get_updates_counter.labels("stream").inc()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a6f1d21674..ed1ff62599 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = await self.auth.get_user_by_access_token(guest_access_token)
- if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+ if (
+ not user_data.is_guest
+ or UserID.from_string(user_data.user_id).localpart != localpart
+ ):
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
user_id=user_id,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index c5b1f1f1e1..e73031475f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -771,22 +771,29 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
- for invitee in invite_list:
+ # we avoid dropping the lock between invites, as otherwise joins can
+ # start coming in and making the createRoom slow.
+ #
+ # we also don't need to check the requester's shadow-ban here, as we
+ # have already done so above (and potentially emptied invite_list).
+ with (await self.room_member_handler.member_linearizer.queue((room_id,))):
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
content["is_direct"] = is_direct
- # Note that update_membership with an action of "invite" can raise a
- # ShadowBanError, but this was handled above by emptying invite_list.
- _, last_stream_id = await self.room_member_handler.update_membership(
- requester,
- UserID.from_string(invitee),
- room_id,
- "invite",
- ratelimit=False,
- content=content,
- )
+ for invitee in invite_list:
+ (
+ _,
+ last_stream_id,
+ ) = await self.room_member_handler.update_membership_locked(
+ requester,
+ UserID.from_string(invitee),
+ room_id,
+ "invite",
+ ratelimit=False,
+ content=content,
+ )
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0268288600..7e5e53a56f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -327,7 +327,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# haproxy would have timed the request out anyway...
raise SynapseError(504, "took to long to process")
- result = await self._update_membership(
+ result = await self.update_membership_locked(
requester,
target,
room_id,
@@ -342,7 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result
- async def _update_membership(
+ async def update_membership_locked(
self,
requester: Requester,
target: UserID,
@@ -355,6 +355,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[str, int]:
+ """Helper for update_membership.
+
+ Assumes that the membership linearizer is already held for the room.
+ """
content_specified = bool(content)
if content is None:
content = {}
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8324632cb6..f409368802 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -359,7 +359,7 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
- **self._extra_treq_args
+ **self._extra_treq_args,
) # type: defer.Deferred
# we use our own timeout mechanism rather than treq's as a workaround
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 65dbd339ac..c0919f8cb7 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer
from twisted.web.util import redirectTo
-import synapse.events
-import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -620,7 +618,7 @@ def respond_with_json(
if pretty_print:
encoder = iterencode_pretty_printed_json
else:
- if canonical_json or synapse.events.USE_FROZEN_DICTS:
+ if canonical_json:
encoder = iterencode_canonical_json
else:
encoder = _encode_json_bytes
diff --git a/synapse/http/site.py b/synapse/http/site.py
index ddb1770b09..5f0581dc3f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Optional
+from typing import Optional, Union
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -54,9 +55,12 @@ class SynapseRequest(Request):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self._channel = channel # this is used by the tests
- self.authenticated_entity = None
self.start_time = 0.0
+ # The requester, if authenticated. For federation requests this is the
+ # server name, for client requests this is the Requester object.
+ self.requester = None # type: Optional[Union[Requester, str]]
+
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext]
@@ -271,11 +275,23 @@ class SynapseRequest(Request):
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
- # need to decode as it could be raw utf-8 bytes
- # from a IDN servname in an auth header
- authenticated_entity = self.authenticated_entity
- if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
- authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+ # Convert the requester into a string that we can log
+ authenticated_entity = None
+ if isinstance(self.requester, str):
+ authenticated_entity = self.requester
+ elif isinstance(self.requester, Requester):
+ authenticated_entity = self.requester.authenticated_entity
+
+ # If this is a request where the target user doesn't match the user who
+ # authenticated (e.g. an admin is puppetting a user) then we log both.
+ if self.requester.user.to_string() != authenticated_entity:
+ authenticated_entity = "{},{}".format(
+ authenticated_entity, self.requester.user.to_string(),
+ )
+ elif self.requester is not None:
+ # This shouldn't happen, but we log it so we don't lose information
+ # and can see that we're doing something wrong.
+ authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
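As a rough illustration of the logged entity this produces (user IDs below are placeholders, not taken from the change):

```python
# Sketch only: shows the "authenticated,target" format used above when an
# admin is puppetting another user. The IDs are illustrative placeholders.
authenticated_entity = "@admin:example.com"   # requester.authenticated_entity
target_user = "@alice:example.com"            # requester.user.to_string()

if target_user != authenticated_entity:
    authenticated_entity = "{},{}".format(authenticated_entity, target_user)

print(authenticated_entity)  # -> @admin:example.com,@alice:example.com
```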
diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py
index e69de29bb2..b28b7b2ef7 100644
--- a/synapse/logging/__init__.py
+++ b/synapse/logging/__init__.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These are imported to allow for nicer logging configuration files.
+from synapse.logging._remote import RemoteHandler
+from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+
+__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"]
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 0caf325916..fb937b3f28 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import sys
import traceback
from collections import deque
@@ -21,10 +22,11 @@ from math import floor
from typing import Callable, Optional
import attr
+from typing_extensions import Deque
from zope.interface import implementer
from twisted.application.internet import ClientService
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
@@ -32,7 +34,9 @@ from twisted.internet.endpoints import (
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import ILogObserver, Logger, LogLevel
+from twisted.python.failure import Failure
+
+logger = logging.getLogger(__name__)
@attr.s
@@ -45,11 +49,11 @@ class LogProducer:
Args:
buffer: Log buffer to read logs from.
transport: Transport to write to.
- format_event: A callable to format the log entry to a string.
+ format: A callable to format the log record to a string.
"""
transport = attr.ib(type=ITransport)
- format_event = attr.ib(type=Callable[[dict], str])
+ _format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -61,16 +65,19 @@ class LogProducer:
self._buffer = deque()
def resumeProducing(self):
+ # If we're already producing, nothing to do.
self._paused = False
+ # Loop until paused.
while self._paused is False and (self._buffer and self.transport.connected):
try:
- # Request the next event and format it.
- event = self._buffer.popleft()
- msg = self.format_event(event)
+ # Request the next record and format it.
+ record = self._buffer.popleft()
+ msg = self._format(record)
# Send it as a new line over the transport.
self.transport.write(msg.encode("utf8"))
+ self.transport.write(b"\n")
except Exception:
# Something has gone wrong writing to the transport -- log it
# and break out of the while.
@@ -78,76 +85,85 @@ class LogProducer:
break
-@attr.s
-@implementer(ILogObserver)
-class TCPLogObserver:
+class RemoteHandler(logging.Handler):
"""
- An IObserver that writes JSON logs to a TCP target.
+ A logging handler that writes logs to a TCP target.
Args:
- hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
- format_event: A callable to format the log entry to a string.
maximum_buffer: The maximum buffer size.
"""
- hs = attr.ib()
- host = attr.ib(type=str)
- port = attr.ib(type=int)
- format_event = attr.ib(type=Callable[[dict], str])
- maximum_buffer = attr.ib(type=int)
- _buffer = attr.ib(default=attr.Factory(deque), type=deque)
- _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
- _logger = attr.ib(default=attr.Factory(Logger))
- _producer = attr.ib(default=None, type=Optional[LogProducer])
-
- def start(self) -> None:
+ def __init__(
+ self,
+ host: str,
+ port: int,
+ maximum_buffer: int = 1000,
+ level=logging.NOTSET,
+ _reactor=None,
+ ):
+ super().__init__(level=level)
+ self.host = host
+ self.port = port
+ self.maximum_buffer = maximum_buffer
+
+ self._buffer = deque() # type: Deque[logging.LogRecord]
+ self._connection_waiter = None # type: Optional[Deferred]
+ self._producer = None # type: Optional[LogProducer]
# Connect without DNS lookups if it's a direct IP.
+ if _reactor is None:
+ from twisted.internet import reactor
+
+ _reactor = reactor
+
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(
- self.hs.get_reactor(), self.host, self.port
- )
+ endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
elif isinstance(ip, IPv6Address):
- endpoint = TCP6ClientEndpoint(
- self.hs.get_reactor(), self.host, self.port
- )
+ endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
raise ValueError("Unknown IP address provided: %s" % (self.host,))
except ValueError:
- endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+ endpoint = HostnameEndpoint(_reactor, self.host, self.port)
factory = Factory.forProtocol(Protocol)
- self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+ self._service = ClientService(endpoint, factory, clock=_reactor)
self._service.startService()
+ self._stopping = False
self._connect()
- def stop(self):
+ def close(self):
+ self._stopping = True
self._service.stopService()
def _connect(self) -> None:
"""
Triggers an attempt to connect then write to the remote if not already writing.
"""
+ # Do not attempt to open multiple connections.
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
- @self._connection_waiter.addErrback
- def fail(r):
- r.printTraceback(file=sys.__stderr__)
+ def fail(failure: Failure) -> None:
+ # If the Deferred was cancelled (e.g. during shutdown) do not try to
+ # reconnect (this will cause an infinite loop of errors).
+ if failure.check(CancelledError) and self._stopping:
+ return
+
+ # For a different error, print the traceback and re-connect.
+ failure.printTraceback(file=sys.__stderr__)
self._connection_waiter = None
self._connect()
- @self._connection_waiter.addCallback
- def writer(r):
+ def writer(result: Protocol) -> None:
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
- if self._producer and r.transport is self._producer.transport:
+ if self._producer and result.transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -158,29 +174,29 @@ class TCPLogObserver:
# Make a new producer and start it.
self._producer = LogProducer(
- buffer=self._buffer,
- transport=r.transport,
- format_event=self.format_event,
+ buffer=self._buffer, transport=result.transport, format=self.format,
)
- r.transport.registerProducer(self._producer, True)
+ result.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
+ self._connection_waiter.addCallbacks(writer, fail)
+
def _handle_pressure(self) -> None:
"""
- Handle backpressure by shedding events.
+ Handle backpressure by shedding records.
The buffer will be shed, in this order, until it is below the maximum:
- - Shed DEBUG events
- - Shed INFO events
- - Shed the middle 50% of the events.
+ - Shed DEBUG records.
+ - Shed INFO records.
+ - Shed the middle 50% of the records.
"""
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out DEBUGs
self._buffer = deque(
- filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
+ filter(lambda record: record.levelno > logging.DEBUG, self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
@@ -188,7 +204,7 @@ class TCPLogObserver:
# Strip out INFOs
self._buffer = deque(
- filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
+ filter(lambda record: record.levelno > logging.INFO, self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
@@ -209,17 +225,17 @@ class TCPLogObserver:
self._buffer.extend(reversed(end_buffer))
- def __call__(self, event: dict) -> None:
- self._buffer.append(event)
+ def emit(self, record: logging.LogRecord) -> None:
+ self._buffer.append(record)
# Handle backpressure, if it exists.
try:
self._handle_pressure()
except Exception:
- # If handling backpressure fails,clear the buffer and log the
+ # If handling backpressure fails, clear the buffer and log the
# exception.
self._buffer.clear()
- self._logger.failure("Failed clearing backpressure")
+ logger.warning("Failed clearing backpressure")
# Try and write immediately.
self._connect()
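Equivalently, the handler can be attached programmatically; a minimal sketch (placeholder host and port):

```python
# Sketch: attaching RemoteHandler by hand rather than via dictConfig.
import logging

from synapse.logging import RemoteHandler, TerseJsonFormatter

handler = RemoteHandler("127.0.0.1", 9000, maximum_buffer=1000)
handler.setFormatter(TerseJsonFormatter())
logging.getLogger("synapse").addHandler(handler)
```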
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 0fc2ea609e..14d9c104c2 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,138 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import os.path
-import sys
-import typing
-import warnings
-from typing import List
+from typing import Any, Dict, Generator, Optional, Tuple
-import attr
-from constantly import NamedConstant, Names, ValueConstant, Values
-from zope.interface import implementer
-
-from twisted.logger import (
- FileLogObserver,
- FilteringLogObserver,
- ILogObserver,
- LogBeginner,
- Logger,
- LogLevel,
- LogLevelFilterPredicate,
- LogPublisher,
- eventAsText,
- jsonFileLogObserver,
-)
+from constantly import NamedConstant, Names
from synapse.config._base import ConfigError
-from synapse.logging._terse_json import (
- TerseJSONToConsoleLogObserver,
- TerseJSONToTCPLogObserver,
-)
-from synapse.logging.context import current_context
-
-
-def stdlib_log_level_to_twisted(level: str) -> LogLevel:
- """
- Convert a stdlib log level to Twisted's log level.
- """
- lvl = level.lower().replace("warning", "warn")
- return LogLevel.levelWithName(lvl)
-
-
-@attr.s
-@implementer(ILogObserver)
-class LogContextObserver:
- """
- An ILogObserver which adds Synapse-specific log context information.
-
- Attributes:
- observer (ILogObserver): The target parent observer.
- """
-
- observer = attr.ib()
-
- def __call__(self, event: dict) -> None:
- """
- Consume a log event and emit it to the parent observer after filtering
- and adding log context information.
-
- Args:
- event (dict)
- """
- # Filter out some useless events that Twisted outputs
- if "log_text" in event:
- if event["log_text"].startswith("DNSDatagramProtocol starting on "):
- return
-
- if event["log_text"].startswith("(UDP Port "):
- return
-
- if event["log_text"].startswith("Timing out client") or event[
- "log_format"
- ].startswith("Timing out client"):
- return
-
- context = current_context()
-
- # Copy the context information to the log event.
- context.copy_to_twisted_log_entry(event)
-
- self.observer(event)
-
-
-class PythonStdlibToTwistedLogger(logging.Handler):
- """
- Transform a Python stdlib log message into a Twisted one.
- """
-
- def __init__(self, observer, *args, **kwargs):
- """
- Args:
- observer (ILogObserver): A Twisted logging observer.
- *args, **kwargs: Args/kwargs to be passed to logging.Handler.
- """
- self.observer = observer
- super().__init__(*args, **kwargs)
-
- def emit(self, record: logging.LogRecord) -> None:
- """
- Emit a record to Twisted's observer.
-
- Args:
- record (logging.LogRecord)
- """
-
- self.observer(
- {
- "log_time": record.created,
- "log_text": record.getMessage(),
- "log_format": "{log_text}",
- "log_namespace": record.name,
- "log_level": stdlib_log_level_to_twisted(record.levelname),
- }
- )
-
-
-def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
- """
- A log observer that formats events like the traditional log formatter and
- sends them to `outFile`.
-
- Args:
- outFile (file object): The file object to write to.
- """
-
- def formatEvent(_event: dict) -> str:
- event = dict(_event)
- event["log_level"] = event["log_level"].name.upper()
- event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
- event.get("log_format", "{log_text}") or "{log_text}"
- )
- return eventAsText(event, includeSystem=False) + "\n"
-
- return FileLogObserver(outFile, formatEvent)
class DrainType(Names):
@@ -155,30 +29,12 @@ class DrainType(Names):
NETWORK_JSON_TERSE = NamedConstant()
-class OutputPipeType(Values):
- stdout = ValueConstant(sys.__stdout__)
- stderr = ValueConstant(sys.__stderr__)
-
-
-@attr.s
-class DrainConfiguration:
- name = attr.ib()
- type = attr.ib()
- location = attr.ib()
- options = attr.ib(default=None)
-
-
-@attr.s
-class NetworkJSONTerseOptions:
- maximum_buffer = attr.ib(type=int)
-
-
-DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
+DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
def parse_drain_configs(
drains: dict,
-) -> typing.Generator[DrainConfiguration, None, None]:
+) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
"""
Parse the drain configurations.
@@ -186,11 +42,12 @@ def parse_drain_configs(
drains (dict): A list of drain configurations.
Yields:
- DrainConfiguration instances.
+ Tuples of (handler name, handler configuration dict), one per drain.
Raises:
ConfigError: If any of the drain configuration items are invalid.
"""
+
for name, config in drains.items():
if "type" not in config:
raise ConfigError("Logging drains require a 'type' key.")
@@ -202,6 +59,18 @@ def parse_drain_configs(
"%s is not a known logging drain type." % (config["type"],)
)
+ # Either use the default formatter or the tersejson one.
+ if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
+ formatter = "json" # type: Optional[str]
+ elif logging_type in (
+ DrainType.CONSOLE_JSON_TERSE,
+ DrainType.NETWORK_JSON_TERSE,
+ ):
+ formatter = "tersejson"
+ else:
+ # A formatter of None implies using the default formatter.
+ formatter = None
+
if logging_type in [
DrainType.CONSOLE,
DrainType.CONSOLE_JSON,
@@ -217,9 +86,11 @@ def parse_drain_configs(
% (logging_type,)
)
- pipe = OutputPipeType.lookupByName(location).value
-
- yield DrainConfiguration(name=name, type=logging_type, location=pipe)
+ yield name, {
+ "class": "logging.StreamHandler",
+ "formatter": formatter,
+ "stream": "ext://sys." + location,
+ }
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
if "location" not in config:
@@ -233,18 +104,25 @@ def parse_drain_configs(
"File paths need to be absolute, '%s' is a relative path"
% (location,)
)
- yield DrainConfiguration(name=name, type=logging_type, location=location)
+
+ yield name, {
+ "class": "logging.FileHandler",
+ "formatter": formatter,
+ "filename": location,
+ }
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
host = config.get("host")
port = config.get("port")
maximum_buffer = config.get("maximum_buffer", 1000)
- yield DrainConfiguration(
- name=name,
- type=logging_type,
- location=(host, port),
- options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
- )
+
+ yield name, {
+ "class": "synapse.logging.RemoteHandler",
+ "formatter": formatter,
+ "host": host,
+ "port": port,
+ "maximum_buffer": maximum_buffer,
+ }
else:
raise ConfigError(
@@ -253,126 +131,29 @@ def parse_drain_configs(
)
-class StoppableLogPublisher(LogPublisher):
+def setup_structured_logging(log_config: dict,) -> dict:
"""
- A log publisher that can tell its observers to shut down any external
- communications.
- """
-
- def stop(self):
- for obs in self._observers:
- if hasattr(obs, "stop"):
- obs.stop()
-
-
-def setup_structured_logging(
- hs,
- config,
- log_config: dict,
- logBeginner: LogBeginner,
- redirect_stdlib_logging: bool = True,
-) -> LogPublisher:
- """
- Set up Twisted's structured logging system.
-
- Args:
- hs: The homeserver to use.
- config (HomeserverConfig): The configuration of the Synapse homeserver.
- log_config (dict): The log configuration to use.
+ Convert a legacy structured logging configuration (from Synapse < v1.23.0)
+ to one compatible with the new standard library handlers.
"""
- if config.no_redirect_stdio:
- raise ConfigError(
- "no_redirect_stdio cannot be defined using structured logging."
- )
-
- logger = Logger()
-
if "drains" not in log_config:
raise ConfigError("The logging configuration requires a list of drains.")
- observers = [] # type: List[ILogObserver]
-
- for observer in parse_drain_configs(log_config["drains"]):
- # Pipe drains
- if observer.type == DrainType.CONSOLE:
- logger.debug(
- "Starting up the {name} console logger drain", name=observer.name
- )
- observers.append(SynapseFileLogObserver(observer.location))
- elif observer.type == DrainType.CONSOLE_JSON:
- logger.debug(
- "Starting up the {name} JSON console logger drain", name=observer.name
- )
- observers.append(jsonFileLogObserver(observer.location))
- elif observer.type == DrainType.CONSOLE_JSON_TERSE:
- logger.debug(
- "Starting up the {name} terse JSON console logger drain",
- name=observer.name,
- )
- observers.append(
- TerseJSONToConsoleLogObserver(observer.location, metadata={})
- )
-
- # File drains
- elif observer.type == DrainType.FILE:
- logger.debug("Starting up the {name} file logger drain", name=observer.name)
- log_file = open(observer.location, "at", buffering=1, encoding="utf8")
- observers.append(SynapseFileLogObserver(log_file))
- elif observer.type == DrainType.FILE_JSON:
- logger.debug(
- "Starting up the {name} JSON file logger drain", name=observer.name
- )
- log_file = open(observer.location, "at", buffering=1, encoding="utf8")
- observers.append(jsonFileLogObserver(log_file))
-
- elif observer.type == DrainType.NETWORK_JSON_TERSE:
- metadata = {"server_name": hs.config.server_name}
- log_observer = TerseJSONToTCPLogObserver(
- hs=hs,
- host=observer.location[0],
- port=observer.location[1],
- metadata=metadata,
- maximum_buffer=observer.options.maximum_buffer,
- )
- log_observer.start()
- observers.append(log_observer)
- else:
- # We should never get here, but, just in case, throw an error.
- raise ConfigError("%s drain type cannot be configured" % (observer.type,))
-
- publisher = StoppableLogPublisher(*observers)
- log_filter = LogLevelFilterPredicate()
-
- for namespace, namespace_config in log_config.get(
- "loggers", DEFAULT_LOGGERS
- ).items():
- # Set the log level for twisted.logger.Logger namespaces
- log_filter.setLogLevelForNamespace(
- namespace,
- stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
- )
-
- # Also set the log levels for the stdlib logger namespaces, to prevent
- # them getting to PythonStdlibToTwistedLogger and having to be formatted
- if "level" in namespace_config:
- logging.getLogger(namespace).setLevel(namespace_config.get("level"))
-
- f = FilteringLogObserver(publisher, [log_filter])
- lco = LogContextObserver(f)
-
- if redirect_stdlib_logging:
- stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
- stdliblogger = logging.getLogger()
- stdliblogger.addHandler(stuff_into_twisted)
-
- # Always redirect standard I/O, otherwise other logging outputs might miss
- # it.
- logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
+ new_config = {
+ "version": 1,
+ "formatters": {
+ "json": {"class": "synapse.logging.JsonFormatter"},
+ "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
+ },
+ "handlers": {},
+ "loggers": log_config.get("loggers", DEFAULT_LOGGERS),
+ "root": {"handlers": []},
+ }
- return publisher
+ for handler_name, handler in parse_drain_configs(log_config["drains"]):
+ new_config["handlers"][handler_name] = handler
+ # Add each handler to the root logger.
+ new_config["root"]["handlers"].append(handler_name)
-def reload_structured_logging(*args, log_config=None) -> None:
- warnings.warn(
- "Currently the structured logging system can not be reloaded, doing nothing"
- )
+ return new_config
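To make the conversion concrete, here is a sketch of a legacy `drains` config and the approximate dict that `setup_structured_logging` now builds for it (all values are placeholders):

```python
# Sketch: a legacy structured-logging config (placeholder values) ...
legacy = {
    "structured": True,
    "drains": {
        "remote": {
            "type": "network_json_terse",
            "host": "127.0.0.1",
            "port": 9000,
            "maximum_buffer": 1000,
        },
    },
}

# ... and roughly what setup_structured_logging(legacy) returns:
converted = {
    "version": 1,
    "formatters": {
        "json": {"class": "synapse.logging.JsonFormatter"},
        "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
    },
    "handlers": {
        "remote": {
            "class": "synapse.logging.RemoteHandler",
            "formatter": "tersejson",
            "host": "127.0.0.1",
            "port": 9000,
            "maximum_buffer": 1000,
        },
    },
    "loggers": {"synapse": {"level": "info"}},
    "root": {"handlers": ["remote"]},
}
```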
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 9b46956ca9..2fbf5549a1 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -16,141 +16,65 @@
"""
Log formatters that output terse JSON.
"""
-
import json
-from typing import IO
-
-from twisted.logger import FileLogObserver
-
-from synapse.logging._remote import TCPLogObserver
+import logging
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
-
-def flatten_event(event: dict, metadata: dict, include_time: bool = False):
- """
- Flatten a Twisted logging event to an dictionary capable of being sent
- as a log event to a logging aggregation system.
-
- The format is vastly simplified and is not designed to be a "human readable
- string" in the sense that traditional logs are. Instead, the structure is
- optimised for searchability and filtering, with human-understandable log
- keys.
-
- Args:
- event (dict): The Twisted logging event we are flattening.
- metadata (dict): Additional data to include with each log message. This
- can be information like the server name. Since the target log
- consumer does not know who we are other than by host IP, this
- allows us to forward through static information.
- include_time (bool): Should we include the `time` key? If False, the
- event time is stripped from the event.
- """
- new_event = {}
-
- # If it's a failure, make the new event's log_failure be the traceback text.
- if "log_failure" in event:
- new_event["log_failure"] = event["log_failure"].getTraceback()
-
- # If it's a warning, copy over a string representation of the warning.
- if "warning" in event:
- new_event["warning"] = str(event["warning"])
-
- # Stdlib logging events have "log_text" as their human-readable portion,
- # Twisted ones have "log_format". For now, include the log_format, so that
- # context only given in the log format (e.g. what is being logged) is
- # available.
- if "log_text" in event:
- new_event["log"] = event["log_text"]
- else:
- new_event["log"] = event["log_format"]
-
- # We want to include the timestamp when forwarding over the network, but
- # exclude it when we are writing to stdout. This is because the log ingester
- # (e.g. logstash, fluentd) can add its own timestamp.
- if include_time:
- new_event["time"] = round(event["log_time"], 2)
-
- # Convert the log level to a textual representation.
- new_event["level"] = event["log_level"].name.upper()
-
- # Ignore these keys, and do not transfer them over to the new log object.
- # They are either useless (isError), transferred manually above (log_time,
- # log_level, etc), or contain Python objects which are not useful for output
- # (log_logger, log_source).
- keys_to_delete = [
- "isError",
- "log_failure",
- "log_format",
- "log_level",
- "log_logger",
- "log_source",
- "log_system",
- "log_time",
- "log_text",
- "observer",
- "warning",
- ]
-
- # If it's from the Twisted legacy logger (twisted.python.log), it adds some
- # more keys we want to purge.
- if event.get("log_namespace") == "log_legacy":
- keys_to_delete.extend(["message", "system", "time"])
-
- # Rather than modify the dictionary in place, construct a new one with only
- # the content we want. The original event should be considered 'frozen'.
- for key in event.keys():
-
- if key in keys_to_delete:
- continue
-
- if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
- # If it's a plain type, include it as is.
- new_event[key] = event[key]
- else:
- # If it's not one of those basic types, write out a string
- # representation. This should probably be a warning in development,
- # so that we are sure we are only outputting useful data.
- new_event[key] = str(event[key])
-
- # Add the metadata information to the event (e.g. the server_name).
- new_event.update(metadata)
-
- return new_event
-
-
-def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
- """
- A log observer that formats events to a flattened JSON representation.
-
- Args:
- outFile: The file object to write to.
- metadata: Metadata to be added to each log object.
- """
-
- def formatEvent(_event: dict) -> str:
- flattened = flatten_event(_event, metadata)
- return _encoder.encode(flattened) + "\n"
-
- return FileLogObserver(outFile, formatEvent)
-
-
-def TerseJSONToTCPLogObserver(
- hs, host: str, port: int, metadata: dict, maximum_buffer: int
-) -> FileLogObserver:
- """
- A log observer that formats events to a flattened JSON representation.
-
- Args:
- hs (HomeServer): The homeserver that is being logged for.
- host: The host of the logging target.
- port: The logging target's port.
- metadata: Metadata to be added to each log object.
- maximum_buffer: The maximum buffer size.
- """
-
- def formatEvent(_event: dict) -> str:
- flattened = flatten_event(_event, metadata, include_time=True)
- return _encoder.encode(flattened) + "\n"
-
- return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
+# The properties of a standard LogRecord.
+_LOG_RECORD_ATTRIBUTES = {
+ "args",
+ "asctime",
+ "created",
+ "exc_info",
+ # exc_text isn't a public attribute, but is used to cache the result of formatException.
+ "exc_text",
+ "filename",
+ "funcName",
+ "levelname",
+ "levelno",
+ "lineno",
+ "message",
+ "module",
+ "msecs",
+ "msg",
+ "name",
+ "pathname",
+ "process",
+ "processName",
+ "relativeCreated",
+ "stack_info",
+ "thread",
+ "threadName",
+}
+
+
+class JsonFormatter(logging.Formatter):
+ def format(self, record: logging.LogRecord) -> str:
+ event = {
+ "log": record.getMessage(),
+ "namespace": record.name,
+ "level": record.levelname,
+ }
+
+ return self._format(record, event)
+
+ def _format(self, record: logging.LogRecord, event: dict) -> str:
+ # Add any extra attributes to the event.
+ for key, value in record.__dict__.items():
+ if key not in _LOG_RECORD_ATTRIBUTES:
+ event[key] = value
+
+ return _encoder.encode(event)
+
+
+class TerseJsonFormatter(JsonFormatter):
+ def format(self, record: logging.LogRecord) -> str:
+ event = {
+ "log": record.getMessage(),
+ "namespace": record.name,
+ "level": record.levelname,
+ "time": round(record.created, 2),
+ }
+
+ return self._format(record, event)
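A small sketch of the output this formatter produces for a single record (the record fields are illustrative):

```python
# Sketch: format one hand-built record; all field values are illustrative.
import logging

from synapse.logging import TerseJsonFormatter

record = logging.LogRecord(
    name="synapse.access.http.8008",
    level=logging.INFO,
    pathname="site.py",
    lineno=1,
    msg="Received request",
    args=(),
    exc_info=None,
)
print(TerseJsonFormatter().format(record))
# -> {"log":"Received request","namespace":"synapse.access.http.8008","level":"INFO","time":...}
```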
diff --git a/synapse/logging/filter.py b/synapse/logging/filter.py
new file mode 100644
index 0000000000..1baf8dd679
--- /dev/null
+++ b/synapse/logging/filter.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from typing_extensions import Literal
+
+
+class MetadataFilter(logging.Filter):
+ """Logging filter that adds constant values to each record.
+
+ Args:
+ metadata: Key-value pairs to add to each record.
+ """
+
+ def __init__(self, metadata: dict):
+ self._metadata = metadata
+
+ def filter(self, record: logging.LogRecord) -> Literal[True]:
+ for key, value in self._metadata.items():
+ setattr(record, key, value)
+ return True
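A short usage sketch (handler and metadata values are placeholders): attaching the filter to a handler stamps every record, so the JSON formatters above will include the extra keys in their output.

```python
# Sketch: add a constant server_name to every record passing this handler.
import logging

from synapse.logging.filter import MetadataFilter

handler = logging.StreamHandler()
handler.addFilter(MetadataFilter({"server_name": "example.com"}))
logging.getLogger("synapse").addHandler(handler)
```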
diff --git a/synapse/notifier.py b/synapse/notifier.py
index eb56b26f21..a17352ef46 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -28,6 +28,7 @@ from typing import (
Union,
)
+import attr
from prometheus_client import Counter
from twisted.internet import defer
@@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
return bool(self.events)
+@attr.s(slots=True, frozen=True)
+class _PendingRoomEventEntry:
+ event_pos = attr.ib(type=PersistedEventPosition)
+ extra_users = attr.ib(type=Collection[UserID])
+
+ room_id = attr.ib(type=str)
+ type = attr.ib(type=str)
+ state_key = attr.ib(type=Optional[str])
+ membership = attr.ib(type=Optional[str])
+
+
class Notifier:
""" This class is responsible for notifying any listeners when there are
new events available for it.
@@ -190,9 +202,7 @@ class Notifier:
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
- self.pending_new_room_events = (
- []
- ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
+ self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -255,7 +265,29 @@ class Notifier:
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
- """ Used by handlers to inform the notifier something has happened
+ """Unwraps event and calls `on_new_room_event_args`.
+ """
+ self.on_new_room_event_args(
+ event_pos=event_pos,
+ room_id=event.room_id,
+ event_type=event.type,
+ state_key=event.get("state_key"),
+ membership=event.content.get("membership"),
+ max_room_stream_token=max_room_stream_token,
+ extra_users=extra_users,
+ )
+
+ def on_new_room_event_args(
+ self,
+ room_id: str,
+ event_type: str,
+ state_key: Optional[str],
+ membership: Optional[str],
+ event_pos: PersistedEventPosition,
+ max_room_stream_token: RoomStreamToken,
+ extra_users: Collection[UserID] = [],
+ ):
+ """Used by handlers to inform the notifier something has happened
in the room, room event wise.
This triggers the notifier to wake up any listeners that are
@@ -266,7 +298,16 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append((event_pos, event, extra_users))
+ self.pending_new_room_events.append(
+ _PendingRoomEventEntry(
+ event_pos=event_pos,
+ extra_users=extra_users,
+ room_id=room_id,
+ type=event_type,
+ state_key=state_key,
+ membership=membership,
+ )
+ )
self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
@@ -284,18 +325,19 @@ class Notifier:
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
- for event_pos, event, extra_users in pending:
- if event_pos.persisted_after(max_room_stream_token):
- self.pending_new_room_events.append((event_pos, event, extra_users))
+ for entry in pending:
+ if entry.event_pos.persisted_after(max_room_stream_token):
+ self.pending_new_room_events.append(entry)
else:
if (
- event.type == EventTypes.Member
- and event.membership == Membership.JOIN
+ entry.type == EventTypes.Member
+ and entry.membership == Membership.JOIN
+ and entry.state_key
):
- self._user_joined_room(event.state_key, event.room_id)
+ self._user_joined_room(entry.state_key, entry.room_id)
- users.update(extra_users)
- rooms.add(event.room_id)
+ users.update(entry.extra_users)
+ rooms.add(entry.room_id)
if users or rooms:
self.on_new_event(
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9b5478b53..82a72dc34f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,8 @@
# limitations under the License.
import logging
-from collections import namedtuple
+import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = await self._get_rules_for_room(room_id)
+ rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = await rules_for_room.get_rules(event, context)
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
return rules_by_user
- @cached()
+ @lru_cache()
def _get_rules_for_room(self, room_id):
"""Get the current RulesForRoom object for the given room id
@@ -275,12 +276,14 @@ class RulesForRoom:
the entire cache for the room.
"""
- def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+ def __init__(
+ self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ ):
"""
Args:
hs (HomeServer)
room_id (str)
- rules_for_room_cache(Cache): The cache object that caches these
+ rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric)
"""
@@ -489,13 +492,21 @@ class RulesForRoom:
self.state_group = state_group
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+ # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+ # which means that it is stored on the bulk_get_push_rules cache entry. In order
+ # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
+ # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+ # same `cache` and `room_id`.
+ #
+ # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+ # set `frozen=True`.
+
+ cache = attr.ib(type=LruCache)
+ room_id = attr.ib(type=str)
+
def __call__(self):
- rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+ rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
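A tiny sketch of the attrs property the comment above relies on: frozen instances with equal fields compare and hash equal, so duplicate invalidation callbacks collapse to one.

```python
# Sketch: frozen attrs classes get value-based __eq__ and __hash__.
import attr


@attr.s(slots=True, frozen=True)
class _Key:
    cache = attr.ib()
    room_id = attr.ib(type=str)


a = _Key("some-cache", "!room:example.com")
b = _Key("some-cache", "!room:example.com")
assert a == b and len({a, b}) == 1  # equal by value, so a set dedupes them
```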
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index e7cc74a5d2..f0c37eaf5e 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index fc129dbaa7..8fa104c8d3 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e27ee216f0..2618eb1e53 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -141,21 +141,25 @@ class ReplicationDataHandler:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
+ assert isinstance(row.data, EventsStreamEventRow)
- event = await self.store.get_event(
- row.data.event_id, allow_rejected=True
- )
- if event.rejected_reason:
+ if row.data.rejected:
continue
extra_users = () # type: Tuple[UserID, ...]
- if event.type == EventTypes.Member:
- extra_users = (UserID.from_string(event.state_key),)
+ if row.data.type == EventTypes.Member and row.data.state_key:
+ extra_users = (UserID.from_string(row.data.state_key),)
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
- self.notifier.on_new_room_event(
- event, event_pos, max_token, extra_users
+ self.notifier.on_new_room_event_args(
+ event_pos=event_pos,
+ max_room_stream_token=max_token,
+ extra_users=extra_users,
+ room_id=row.data.room_id,
+ event_type=row.data.type,
+ state_key=row.data.state_key,
+ membership=row.data.membership,
)
# Notify any waiting deferreds. The list is ordered by position so we
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 82e9e0d64e..86a62b71eb 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import List, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
- relates_to = attr.ib() # str, optional
+ event_id = attr.ib(type=str)
+ room_id = attr.ib(type=str)
+ type = attr.ib(type=str)
+ state_key = attr.ib(type=Optional[str])
+ redacts = attr.ib(type=Optional[str])
+ relates_to = attr.ib(type=Optional[str])
+ membership = attr.ib(type=Optional[str])
+ rejected = attr.ib(type=bool)
@attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
NAME = "events"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index a79996cae1..fa7e9e4043 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
+ PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
@@ -226,8 +227,9 @@ def register_servlets(hs, http_server):
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
- EventReportsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
+ EventReportsRestServlet(hs).register(http_server)
+ PushersRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 933bb45346..b337311a37 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -39,6 +39,17 @@ from synapse.types import JsonDict, UserID
logger = logging.getLogger(__name__)
+_GET_PUSHERS_ALLOWED_KEYS = {
+ "app_display_name",
+ "app_id",
+ "data",
+ "device_display_name",
+ "kind",
+ "lang",
+ "profile_tag",
+ "pushkey",
+}
+
class UsersRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
@@ -713,6 +724,47 @@ class UserMembershipRestServlet(RestServlet):
return 200, ret
+class PushersRestServlet(RestServlet):
+ """
+ Gets information about all pushers for a specific `user_id`.
+
+ Example:
+ http://localhost:8008/_synapse/admin/v1/users/
+ @user:server/pushers
+
+ Returns:
+ pushers: A list of dictionaries, one per pusher.
+ total: Number of pushers in the list `pushers`.
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
+
+ def __init__(self, hs):
+ self.is_mine = hs.is_mine
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only lookup local users")
+
+ if not await self.store.get_user_by_id(user_id):
+ raise NotFoundError("User not found")
+
+ pushers = await self.store.get_pushers_by_user_id(user_id)
+
+ filtered_pushers = [
+ {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
+ for p in pushers
+ ]
+
+ return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
+
+
class UserMediaRestServlet(RestServlet):
"""
Gets information about all uploaded local media for a specific `user_id`.
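For context, a rough sketch of calling the new pushers endpoint (server, user ID and token are placeholders; the response shape follows the servlet above):

```python
# Sketch: query the new admin endpoint; all values are placeholders.
import requests

resp = requests.get(
    "http://localhost:8008/_synapse/admin/v1/users/@alice:example.com/pushers",
    headers={"Authorization": "Bearer <admin access token>"},
)
print(resp.json())  # e.g. {"pushers": [...], "total": 1}
```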
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 5cce7237a0..9cac74ebd8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -305,15 +305,12 @@ class MediaRepository:
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
# one.
- if media_info:
- file_id = media_info["filesystem_id"]
- else:
- file_id = random_string(24)
-
- file_info = FileInfo(server_name, file_id)
# If we have an entry in the DB, try and look for it
if media_info:
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
@@ -324,14 +321,34 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
- media_info = await self._download_remote_file(server_name, media_id, file_id)
+ try:
+ media_info = await self._download_remote_file(server_name, media_id,)
+ except SynapseError:
+ raise
+ except Exception as e:
+ # An exception may be because we downloaded media in another
+ # process, so let's check if we magically have the media.
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
+ if not media_info:
+ raise e
+
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
+ # We generate thumbnails even if another process downloaded the media
+ # as a) it's conceivable that the other download request dies before it
+ # generates thumbnails, but mainly b) we want to be sure the thumbnails
+ # have finished being generated before responding to the client,
+ # otherwise they'll request thumbnails and get a 404 if they're not
+ # ready yet.
+ await self._generate_thumbnails(
+ server_name, media_id, file_id, media_info["media_type"]
+ )
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(
- self, server_name: str, media_id: str, file_id: str
- ) -> dict:
+ async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -346,6 +363,8 @@ class MediaRepository:
The media info of the file.
"""
+ file_id = random_string(24)
+
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -401,22 +420,32 @@ class MediaRepository:
await finish()
- media_type = headers[b"Content-Type"][0].decode("ascii")
- upload_name = get_filename_from_headers(headers)
- time_now_ms = self.clock.time_msec()
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ upload_name = get_filename_from_headers(headers)
+ time_now_ms = self.clock.time_msec()
+
+ # Multiple remote media download requests can race (when using
+ # multiple media repos), so this may throw a constraint violation
+ # exception. If it does we'll delete the newly downloaded file from
+ # disk (as we're in the ctx manager).
+ #
+ # However: we've already called `finish()` so we may have also
+ # written to the storage providers. This is preferable to the
+ # alternative where we call `finish()` *after* this, where we could
+ # end up having an entry in the DB but fail to write the files to
+ # the storage providers.
+ await self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
logger.info("Stored remote media in file %r", fname)
- await self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
-
media_info = {
"media_type": media_type,
"media_length": length,
@@ -425,8 +454,6 @@ class MediaRepository:
"filesystem_id": file_id,
}
- await self._generate_thumbnails(server_name, media_id, file_id, media_type)
-
return media_info
def _get_thumbnail_requirements(self, media_type):
@@ -692,42 +719,60 @@ class MediaRepository:
if not t_byte_source:
continue
- try:
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
- url_cache=url_cache,
- )
-
- output_path = await self.media_storage.store_file(
- t_byte_source, file_info
- )
- finally:
- t_byte_source.close()
-
- t_len = os.path.getsize(output_path)
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
- # Write to database
- if server_name:
- await self.store.store_remote_media_thumbnail(
- server_name,
- media_id,
- file_id,
- t_width,
- t_height,
- t_type,
- t_method,
- t_len,
- )
- else:
- await self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ await self.media_storage.write_to_file(t_byte_source, f)
+ await finish()
+ finally:
+ t_byte_source.close()
+
+ t_len = os.path.getsize(fname)
+
+ # Write to database
+ if server_name:
+ # Multiple remote media download requests can race (when
+ # using multiple media repos), so this may throw a constraint
+ # violation exception. If it does we'll delete the newly
+ # generated thumbnail from disk (as we're in the ctx
+ # manager).
+ #
+ # However: we've already called `finish()` so we may have
+ # also written to the storage providers. This is preferable
+ # to the alternative where we call `finish()` *after* this,
+ # where we could end up having an entry in the DB but fail
+ # to write the files to the storage providers.
+ try:
+ await self.store.store_remote_media_thumbnail(
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
+ )
+ except Exception as e:
+ thumbnail_exists = await self.store.get_remote_media_thumbnail(
+ server_name, media_id, t_width, t_height, t_type,
+ )
+ if not thumbnail_exists:
+ raise e
+ else:
+ await self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
return {"width": m_width, "height": m_height}
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage:
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
+ self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@@ -70,13 +71,16 @@ class MediaStorage:
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- await defer_to_thread(
- self.hs.get_reactor(), _write_file_synchronously, source, f
- )
+ await self.write_to_file(source, f)
await finish_cb()
return fname
+ async def write_to_file(self, source: IO, output: IO):
+ """Asynchronously write the `source` to `output`.
+ """
+ await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
@@ -112,14 +116,20 @@ class MediaStorage:
finished_called = [False]
- async def finish():
- for provider in self.storage_providers:
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
try:
with open(fname, "wb") as f:
+
+ async def finish():
+ # Ensure that all writes have been flushed and close the
+ # file.
+ f.flush()
+ f.close()
+
+ for provider in self.storage_providers:
+ await provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
yield f, fname, finish
except Exception:
try:
@@ -210,7 +220,7 @@ class MediaStorage:
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.hs.get_reactor()
+ open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0217e63108..a0572b2952 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -94,7 +94,7 @@ def make_pool(
cp_openfun=lambda conn: engine.on_new_connection(
LoggingDatabaseConnection(conn, engine, "on_new_connection")
),
- **db_config.config.get("args", {})
+ **db_config.config.get("args", {}),
)
@@ -632,7 +632,7 @@ class DatabasePool:
func,
*args,
db_autocommit=db_autocommit,
- **kwargs
+ **kwargs,
)
for after_callback, after_args, after_kwargs in after_callbacks:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 637a938bac..26eef6eb61 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,21 +15,31 @@
# limitations under the License.
import logging
import re
-from typing import List
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
-from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.appservice import (
+ ApplicationService,
+ ApplicationServiceState,
+ AppServiceTransaction,
+)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+ services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
@@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
- exclusive_user_regex = re.compile(exclusive_user_regex)
+ exclusive_user_pattern = re.compile(
+ exclusive_user_regex
+ ) # type: Optional[Pattern]
else:
# We handle this case specially otherwise the constructed regex
# will always match
- exclusive_user_regex = None
+ exclusive_user_pattern = None
- return exclusive_user_regex
+ return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
- def get_if_app_services_interested_in_user(self, user_id):
+ def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
if self.exclusive_user_regex:
@@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else:
return False
- def get_app_service_by_user_id(self, user_id):
+ def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
@@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service.
Args:
- user_id(str): The user ID to see if it is an application service.
+ user_id: The user ID to see if it is an application service.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.sender == user_id:
return service
return None
- def get_app_service_by_token(self, token):
+ def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token.
Args:
- token (str): The application service token.
+ token: The application service token.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.token == token:
return service
return None
- def get_app_service_by_id(self, as_id):
+ def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID.
Args:
- as_id (str): The application service ID.
+ as_id: The application service ID.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.id == as_id:
@@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
- async def get_appservices_by_state(self, state):
+ async def get_appservices_by_state(
+ self, state: ApplicationServiceState
+ ) -> List[ApplicationService]:
"""Get a list of application services based on their state.
Args:
- state(ApplicationServiceState): The state to filter on.
+ state: The state to filter on.
Returns:
A list of ApplicationServices, which may be empty.
"""
@@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
- async def get_appservice_state(self, service):
+ async def get_appservice_state(
+ self, service: ApplicationService
+ ) -> Optional[ApplicationServiceState]:
"""Get the application service state.
Args:
- service(ApplicationService): The service whose state to set.
+ service: The service whose state to get.
Returns:
- An ApplicationServiceState.
+ An ApplicationServiceState, or None if we have no recorded state for the service.
"""
result = await self.db_pool.simple_select_one(
"application_services_state",
@@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- async def set_appservice_state(self, service, state) -> None:
+ async def set_appservice_state(
+ self, service: ApplicationService, state: ApplicationServiceState
+ ) -> None:
"""Set the application service state.
Args:
- service(ApplicationService): The service whose state to set.
- state(ApplicationServiceState): The connectivity state to apply.
+ service: The service whose state to set.
+ state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
@@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
"create_appservice_txn", _create_appservice_txn
)
- async def complete_appservice_txn(self, txn_id, service) -> None:
+ async def complete_appservice_txn(
+ self, txn_id: int, service: ApplicationService
+ ) -> None:
"""Completes an application service transaction.
Args:
- txn_id(str): The transaction ID being completed.
- service(ApplicationService): The application service which was sent
- this transaction.
+ txn_id: The transaction ID being completed.
+ service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
@@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
@@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn
)
- async def get_oldest_unsent_txn(self, service):
- """Get the oldest transaction which has not been sent for this
- service.
+ async def get_oldest_unsent_txn(
+ self, service: ApplicationService
+ ) -> Optional[AppServiceTransaction]:
+ """Get the oldest transaction which has not been sent for this service.
Args:
- service(ApplicationService): The app service to get the oldest txn.
+ service: The app service to get the oldest txn.
Returns:
An AppServiceTransaction or None.
"""
@@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
)
- def _get_last_txn(self, txn, service_id):
+ def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
@@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
- async def set_appservice_last_pos(self, pos) -> None:
+ async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
- async def get_new_events_for_appservice(self, current_id, limit):
+ async def get_new_events_for_appservice(
+ self, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
@@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
- self, service: ApplicationService, type: str, pos: int
+ self, service: ApplicationService, type: str, pos: Optional[int]
) -> None:
if type not in ("read_receipt", "presence"):
raise ValueError(
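As an aside, the regex-combining logic that `_make_exclusive_regex` now annotates can be sketched on its own; the namespace patterns below are hypothetical examples, not ones taken from this patch:

    import re
    from typing import List, Optional, Pattern

    def make_exclusive_regex(exclusive_user_regexes: List[str]) -> Optional[Pattern]:
        # Combine each AS's exclusive-user regex into one alternation; an empty
        # list yields None so "no exclusive namespaces" never matches anything.
        if not exclusive_user_regexes:
            return None
        return re.compile("|".join("(" + r + ")" for r in exclusive_user_regexes))

    pattern = make_exclusive_regex([r"@irc_.*:example\.com", r"@gitter_.*:example\.com"])
    assert pattern is not None and pattern.match("@irc_alice:example.com")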
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 849bd5ba7a..3e26d5ba87 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = frozendict_json_encoder.encode(
+ pruned_json = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
- pruned_json = frozendict_json_encoder.encode(
+ pruned_json = json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 87808c1483..90fb1a1f00 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -769,9 +769,7 @@ class PersistEventsStore:
logger.exception("")
raise
- metadata_json = frozendict_json_encoder.encode(
- event.internal_metadata.get_dict()
- )
+ metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -826,10 +824,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
- "internal_metadata": frozendict_json_encoder.encode(
+ "internal_metadata": json_encoder.encode(
event.internal_metadata.get_dict()
),
- "json": frozendict_json_encoder.encode(event_dict(event)),
+ "json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6e7f16f39c..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -31,6 +31,7 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import (
@@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
return event_map
+ async def get_stripped_room_state_from_event_context(
+ self,
+ context: EventContext,
+ state_types_to_include: List[EventTypes],
+ membership_user_id: Optional[str] = None,
+ ) -> List[JsonDict]:
+ """
+ Retrieve the stripped state from a room, given an event context to retrieve state
+ from as well as the state types to include. Optionally, include the membership
+ events from a specific user.
+
+ "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+ are included from each state event.
+
+ Args:
+ context: The event context to retrieve state of the room from.
+ state_types_to_include: The types of state events to include.
+ membership_user_id: An optional user ID to include the stripped membership state
+ events of. This is useful when generating the stripped state of a room for
+ invites. We want to send membership events of the inviter, so that the
+ invitee can display the inviter's profile information if the room lacks any.
+
+ Returns:
+ A list of dictionaries, each representing a stripped state event from the room.
+ """
+ current_state_ids = await context.get_current_state_ids()
+
+ # We know this event is not an outlier, so this must be
+ # non-None.
+ assert current_state_ids is not None
+
+ # The state to include
+ state_to_include_ids = [
+ e_id
+ for k, e_id in current_state_ids.items()
+ if k[0] in state_types_to_include
+ or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+ ]
+
+ state_to_include = await self.get_events(state_to_include_ids)
+
+ return [
+ {
+ "type": e.type,
+ "state_key": e.state_key,
+ "content": e.content,
+ "sender": e.sender,
+ }
+ for e in state_to_include.values()
+ ]
+
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
@@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
+ " LEFT JOIN room_memberships USING (event_id)"
+ " LEFT JOIN rejections USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" AND instance_name = ?"
" ORDER BY stream_ordering ASC"
@@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
+ " LEFT JOIN room_memberships USING (event_id)"
+ " LEFT JOIN rejections USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" AND out.instance_name = ?"
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index daf57675d8..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
+ async def get_remote_media_thumbnail(
+ self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+ ) -> Optional[Dict[str, Any]]:
+ """Fetch the thumbnail info of given width, height and type.
+ """
+
+ return await self.db_pool.simple_select_one(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": origin,
+ "media_id": media_id,
+ "thumbnail_width": t_width,
+ "thumbnail_height": t_height,
+ "thumbnail_type": t_type,
+ },
+ retcols=(
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ "filesystem_id",
+ ),
+ allow_none=True,
+ desc="get_remote_media_thumbnail",
+ )
+
async def store_remote_media_thumbnail(
self,
origin,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e7b17a7385..e5d07ce72a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -18,6 +18,8 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+ """Result of looking up an access token.
+
+ Attributes:
+ user_id: The user that this token authenticates as
+ is_guest: Whether the user is a guest user.
+ shadow_banned: Whether the user is shadow-banned.
+ token_id: The ID of the access token looked up
+ device_id: The device associated with the token, if any.
+ valid_until_ms: The timestamp the token expires, if any.
+ token_owner: The "owner" of the token. This is either the same as the
+ user, or a server admin who is logged in as the user.
+ """
+
+ user_id = attr.ib(type=str)
+ is_guest = attr.ib(type=bool, default=False)
+ shadow_banned = attr.ib(type=bool, default=False)
+ token_id = attr.ib(type=Optional[int], default=None)
+ device_id = attr.ib(type=Optional[str], default=None)
+ valid_until_ms = attr.ib(type=Optional[int], default=None)
+ token_owner = attr.ib(type=str)
+
+ # Make the token owner default to the user ID, which is the common case.
+ @token_owner.default
+ def _default_token_owner(self):
+ return self.user_id
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial
@cached()
- async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+ async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
- None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise a `TokenLookupResult`
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
- def _query_for_auth(self, txn, token):
+ def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
- SELECT users.name,
+ SELECT users.name as user_id,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
- access_tokens.valid_until_ms
+ access_tokens.valid_until_ms,
+ access_tokens.user_id as token_owner
FROM users
- INNER JOIN access_tokens on users.name = access_tokens.user_id
+ INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
- return rows[0]
+ return TokenLookupResult(**rows[0])
return None
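The `TokenLookupResult` class added above relies on attrs' decorator-based defaults so that `token_owner` falls back to `user_id` unless an admin is puppeting another user. A minimal sketch of that pattern, with hypothetical values:

    from typing import Optional

    import attr

    @attr.s(frozen=True, slots=True)
    class TokenInfo:
        user_id = attr.ib(type=str)
        token_id = attr.ib(type=Optional[int], default=None)
        token_owner = attr.ib(type=str)

        # The owner defaults to the authenticated user; an admin "puppeting"
        # another user overrides it explicitly.
        @token_owner.default
        def _default_token_owner(self):
            return self.user_id

    info = TokenInfo(user_id="@alice:example.com", token_id=5)
    assert info.token_owner == "@alice:example.com"

    puppet = TokenInfo(user_id="@alice:example.com", token_owner="@admin:example.com")
    assert puppet.token_owner == "@admin:example.com"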
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
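To illustrate how the new `puppets_user_id` column interacts with the `COALESCE` in the auth lookup, here is a small self-contained sqlite3 sketch; the schema is a stripped-down, hypothetical subset of the real tables:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE users (name TEXT);
        CREATE TABLE access_tokens (
            id INTEGER, user_id TEXT, puppets_user_id TEXT, token TEXT
        );
        INSERT INTO users VALUES ('@alice:test'), ('@admin:test');
        -- An admin token that puppets @alice:test.
        INSERT INTO access_tokens VALUES (1, '@admin:test', '@alice:test', 'secret');
        """
    )

    row = conn.execute(
        """
        SELECT users.name AS user_id, access_tokens.user_id AS token_owner
        FROM users
        INNER JOIN access_tokens
            ON users.name = COALESCE(puppets_user_id, access_tokens.user_id)
        WHERE token = ?
        """,
        ("secret",),
    ).fetchone()

    # The token authenticates as the puppeted user, while recording the real owner.
    assert row == ("@alice:test", "@admin:test")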
diff --git a/synapse/types.py b/synapse/types.py
index 5bde67cc07..66bb5bac8d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -29,6 +29,7 @@ from typing import (
Tuple,
Type,
TypeVar,
+ Union,
)
import attr
@@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING:
+ from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5
@@ -74,6 +76,7 @@ class Requester(
"shadow_banned",
"device_id",
"app_service",
+ "authenticated_entity",
],
)
):
@@ -104,6 +107,7 @@ class Requester(
"shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
+ "authenticated_entity": self.authenticated_entity,
}
@staticmethod
@@ -129,16 +133,18 @@ class Requester(
shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
+ authenticated_entity=input["authenticated_entity"],
)
def create_requester(
- user_id,
- access_token_id=None,
- is_guest=False,
- shadow_banned=False,
- device_id=None,
- app_service=None,
+ user_id: Union[str, "UserID"],
+ access_token_id: Optional[int] = None,
+ is_guest: Optional[bool] = False,
+ shadow_banned: Optional[bool] = False,
+ device_id: Optional[str] = None,
+ app_service: Optional["ApplicationService"] = None,
+ authenticated_entity: Optional[str] = None,
):
"""
Create a new ``Requester`` object
@@ -151,14 +157,27 @@ def create_requester(
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
+
+ if authenticated_entity is None:
+ authenticated_entity = user_id.to_string()
+
return Requester(
- user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+ user_id,
+ access_token_id,
+ is_guest,
+ shadow_banned,
+ device_id,
+ app_service,
+ authenticated_entity,
)
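The fallback added to `create_requester` is easy to demonstrate in isolation; the `SimpleRequester` helper below is a hypothetical stand-in for the real `Requester` tuple:

    from typing import NamedTuple, Optional

    class SimpleRequester(NamedTuple):
        user_id: str
        authenticated_entity: str

    def make_requester(user_id: str, authenticated_entity: Optional[str] = None) -> SimpleRequester:
        # The common case: the user authenticated as themselves.
        if authenticated_entity is None:
            authenticated_entity = user_id
        return SimpleRequester(user_id, authenticated_entity)

    assert make_requester("@alice:test").authenticated_entity == "@alice:test"
    # A server admin puppeting another user records who really authenticated.
    assert make_requester("@alice:test", "@admin:test").authenticated_entity == "@admin:test"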
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index d55b93d763..517686f0a6 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -18,6 +18,7 @@ import logging
import re
import attr
+from frozendict import frozendict
from twisted.internet import defer, task
@@ -31,9 +32,26 @@ def _reject_invalid_json(val):
raise ValueError("Invalid JSON value: '%s'" % val)
-# Create a custom encoder to reduce the whitespace produced by JSON encoding and
-# ensure that valid JSON is produced.
-json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
+def _handle_frozendict(obj):
+ """Helper for json_encoder. Makes frozendicts serializable by returning
+ the underlying dict
+ """
+ if type(obj) is frozendict:
+ # fishing the protected dict out of the object is a bit nasty,
+ # but we don't really want the overhead of copying the dict.
+ return obj._dict
+ raise TypeError(
+ "Object of type %s is not JSON serializable" % obj.__class__.__name__
+ )
+
+
+# A custom JSON encoder which:
+# * handles frozendicts
+# * produces valid JSON (no NaNs etc)
+# * reduces redundant whitespace
+json_encoder = json.JSONEncoder(
+ allow_nan=False, separators=(",", ":"), default=_handle_frozendict
+)
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
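A self-contained sketch of the merged encoder behaviour (assuming the `frozendict` package is available); the `default=` hook is what lets a single `JSONEncoder` cover frozendicts while still rejecting NaN and minimising whitespace:

    import json

    from frozendict import frozendict

    def handle_frozendict(obj):
        # Serialize frozendicts by exposing their contents as a plain dict.
        if type(obj) is frozendict:
            return dict(obj)
        raise TypeError(
            "Object of type %s is not JSON serializable" % obj.__class__.__name__
        )

    encoder = json.JSONEncoder(
        allow_nan=False, separators=(",", ":"), default=handle_frozendict
    )

    print(encoder.encode({"content": frozendict({"body": "hi"})}))
    # -> {"content":{"body":"hi"}}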
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import functools
import inspect
import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
from weakref import WeakValueDictionary
from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
- def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+ def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
self.add_cache_context = cache_context
+ self.cache_key_builder = get_cache_key_builder(
+ self.arg_names, self.arg_defaults
+ )
+
+
+class _LruCachedFunction(Generic[F]):
+ cache = None # type: LruCache[CacheKey, Any]
+ __call__ = None # type: F
+
+
+def lru_cache(
+ max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+ """A method decorator that applies a memoizing cache around the function.
+
+ This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+ that the signature is slightly different.
+
+ The main differences with functools.lru_cache are:
+ (a) the size of the cache can be controlled via the cache_factor mechanism
+ (b) the wrapped function can request a "cache_context" which provides a
+ callback mechanism to indicate that the result is no longer valid
+ (c) prometheus metrics are exposed automatically.
+
+ The function should take zero or more arguments, which are used as the key for the
+ cache. Single-argument functions use that argument as the cache key; otherwise the
+ arguments are built into a tuple.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+ functions and get appropriately invalidated when the caches they call are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example:
+
+ @lru_cache(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+ return r1 + r2
+
+ The wrapped function also has a 'cache' property which offers direct access to the
+ underlying LruCache.
+ """
+
+ def func(orig: F) -> _LruCachedFunction[F]:
+ desc = LruCacheDescriptor(
+ orig, max_entries=max_entries, cache_context=cache_context,
+ )
+ return cast(_LruCachedFunction[F], desc)
+
+ return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+ """Helper for @lru_cache"""
+
+ class _Sentinel(enum.Enum):
+ sentinel = object()
+
+ def __init__(
+ self, orig, max_entries: int = 1000, cache_context: bool = False,
+ ):
+ super().__init__(orig, num_args=None, cache_context=cache_context)
+ self.max_entries = max_entries
+
+ def __get__(self, obj, owner):
+ cache = LruCache(
+ cache_name=self.orig.__name__, max_size=self.max_entries,
+ ) # type: LruCache[CacheKey, Any]
+
+ get_cache_key = self.cache_key_builder
+ sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+ @functools.wraps(self.orig)
+ def _wrapped(*args, **kwargs):
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+ callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+ cache_key = get_cache_key(args, kwargs)
-class CacheDescriptor(_CacheDescriptorBase):
+ ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+ if ret != sentinel:
+ return ret
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+ ret2 = self.orig(obj, *args, **kwargs)
+ cache.set(cache_key, ret2, callbacks=callbacks)
+
+ return ret2
+
+ wrapped = cast(_CachedFunction, _wrapped)
+ wrapped.cache = cache
+ obj.__dict__[self.orig.__name__] = wrapped
+
+ return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=False,
iterable=False,
):
-
super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries
@@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any]
- def get_cache_key_gen(args, kwargs):
- """Given some args/kwargs return a generator that resolves into
- the cache_key.
-
- We loop through each arg name, looking up if its in the `kwargs`,
- otherwise using the next argument in `args`. If there are no more
- args then we try looking the arg name up in the defaults
- """
- pos = 0
- for nm in self.arg_names:
- if nm in kwargs:
- yield kwargs[nm]
- elif pos < len(args):
- yield args[pos]
- pos += 1
- else:
- yield self.arg_defaults[nm]
-
- # By default our cache key is a tuple, but if there is only one item
- # then don't bother wrapping in a tuple. This is to save memory.
- if self.num_args == 1:
- nm = self.arg_names[0]
-
- def get_cache_key(args, kwargs):
- if nm in kwargs:
- return kwargs[nm]
- elif len(args):
- return args[0]
- else:
- return self.arg_defaults[nm]
-
- else:
-
- def get_cache_key(args, kwargs):
- return tuple(get_cache_key_gen(args, kwargs))
+ get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
@@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
@@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return wrapped
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
@@ -382,11 +459,13 @@ class _CacheContext:
on a lower level.
"""
+ Cache = Union[DeferredCache, LruCache]
+
_cache_context_objects = (
WeakValueDictionary()
- ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
+ ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
- def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
+ def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache
self._cache_key = cache_key
@@ -396,8 +475,8 @@ class _CacheContext:
@classmethod
def get_instance(
- cls, cache, cache_key
- ): # type: (DeferredCache, CacheKey) -> _CacheContext
+ cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+ ) -> "_CacheContext":
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
@@ -418,7 +497,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
- func = lambda orig: CacheDescriptor(
+ func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@@ -460,7 +539,7 @@ def cachedList(
def batch_do_something(self, first_arg, second_args):
...
"""
- func = lambda orig: CacheListDescriptor(
+ func = lambda orig: DeferredCacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
@@ -468,3 +547,65 @@ def cachedList(
)
return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+ param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+ """Construct a function which will build cache keys suitable for a cached function
+
+ Args:
+ param_names: list of formal parameter names for the cached function
+ param_defaults: a mapping from parameter name to default value for that param
+
+ Returns:
+ A function which will take an (args, kwargs) pair and return a cache key
+ """
+
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+
+ if len(param_names) == 1:
+ nm = param_names[0]
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return param_defaults[nm]
+
+ else:
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+ return get_cache_key
+
+
+def _get_cache_key_gen(
+ param_names: Iterable[str],
+ param_defaults: Mapping[str, Any],
+ args: Sequence[Any],
+ kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ This is essentially the same operation as `inspect.getcallargs`, but optimised so
+ that we don't need to inspect the target function for each call.
+ """
+
+ # We loop through each arg name, looking it up in `kwargs` first, otherwise
+ # using the next argument in `args`. If there are no more args then we fall
+ # back to the arg name's entry in the defaults.
+ pos = 0
+ for nm in param_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield param_defaults[nm]
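The key-building behaviour factored out into `get_cache_key_builder` can be exercised on its own: positional args, keyword args and declared defaults all resolve to the same key. The example function signature below is hypothetical:

    from typing import Any, Mapping, Sequence

    def build_key(
        param_names: Sequence[str],
        param_defaults: Mapping[str, Any],
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
    ):
        # Resolve each declared parameter from kwargs, then positionally,
        # then from its default -- mirroring inspect.getcallargs, but cheaper.
        key = []
        pos = 0
        for name in param_names:
            if name in kwargs:
                key.append(kwargs[name])
            elif pos < len(args):
                key.append(args[pos])
                pos += 1
            else:
                key.append(param_defaults[name])
        # Single-parameter functions use the bare value to save memory.
        return key[0] if len(param_names) == 1 else tuple(key)

    # e.g. for a hypothetical `def get_event(event_id, allow_rejected=False)`:
    names, defaults = ["event_id", "allow_rejected"], {"allow_rejected": False}
    assert build_key(names, defaults, ("$abc",), {}) == build_key(names, defaults, (), {"event_id": "$abc"})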
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index bf094c9386..5f7a6dd1d3 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-
from frozendict import frozendict
@@ -49,23 +47,3 @@ def unfreeze(o):
pass
return o
-
-
-def _handle_frozendict(obj):
- """Helper for EventEncoder. Makes frozendicts serializable by returning
- the underlying dict
- """
- if type(obj) is frozendict:
- # fishing the protected dict out of the object is a bit nasty,
- # but we don't really want the overhead of copying the dict.
- return obj._dict
- raise TypeError(
- "Object of type %s is not JSON serializable" % obj.__class__.__name__
- )
-
-
-# A JSONEncoder which is capable of encoding frozendicts without barfing.
-# Additionally reduce the whitespace produced by JSON encoding.
-frozendict_json_encoder = json.JSONEncoder(
- allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
-)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index a5cc9d0551..4ab379e429 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
failure_ts,
retry_interval,
backoff_on_failure=backoff_on_failure,
- **kwargs
+ **kwargs,
)
diff --git a/synmark/__init__.py b/synmark/__init__.py
index 09bc7e7927..3d4ec3e184 100644
--- a/synmark/__init__.py
+++ b/synmark/__init__.py
@@ -21,45 +21,6 @@ except ImportError:
from twisted.internet.pollreactor import PollReactor as Reactor
from twisted.internet.main import installReactor
-from synapse.config.homeserver import HomeServerConfig
-from synapse.util import Clock
-
-from tests.utils import default_config, setup_test_homeserver
-
-
-async def make_homeserver(reactor, config=None):
- """
- Make a Homeserver suitable for running benchmarks against.
-
- Args:
- reactor: A Twisted reactor to run under.
- config: A HomeServerConfig to use, or None.
- """
- cleanup_tasks = []
- clock = Clock(reactor)
-
- if not config:
- config = default_config("test")
-
- config_obj = HomeServerConfig()
- config_obj.parse_config_dict(config, "", "")
-
- hs = setup_test_homeserver(
- cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
- )
- stor = hs.get_datastore()
-
- # Run the database background updates.
- if hasattr(stor.db_pool.updates, "do_next_background_update"):
- while not await stor.db_pool.updates.has_completed_background_updates():
- await stor.db_pool.updates.do_next_background_update(1)
-
- def cleanup():
- for i in cleanup_tasks:
- i()
-
- return hs, clock.sleep, cleanup
-
def make_reactor():
"""
diff --git a/synmark/__main__.py b/synmark/__main__.py
index 17df9ddeb7..de13c1a909 100644
--- a/synmark/__main__.py
+++ b/synmark/__main__.py
@@ -12,20 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import sys
from argparse import REMAINDER
from contextlib import redirect_stderr
from io import StringIO
import pyperf
-from synmark import make_reactor
-from synmark.suites import SUITES
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.logger import globalLogBeginner, textFileLogObserver
from twisted.python.failure import Failure
+from synmark import make_reactor
+from synmark.suites import SUITES
+
from tests.utils import setupdb
diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py
index d8e4c7d58f..c9d9cf761e 100644
--- a/synmark/suites/logging.py
+++ b/synmark/suites/logging.py
@@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import warnings
from io import StringIO
from mock import Mock
from pyperf import perf_counter
-from synmark import make_homeserver
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ServerFactory
-from twisted.logger import LogBeginner, Logger, LogPublisher
+from twisted.logger import LogBeginner, LogPublisher
from twisted.protocols.basic import LineOnlyReceiver
-from synapse.logging._structured import setup_structured_logging
+from synapse.config.logger import _setup_stdlib_logging
+from synapse.logging import RemoteHandler
+from synapse.util import Clock
class LineCounter(LineOnlyReceiver):
@@ -62,7 +64,15 @@ async def main(reactor, loops):
logger_factory.on_done = Deferred()
port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1")
- hs, wait, cleanup = await make_homeserver(reactor)
+ # A fake homeserver config.
+ class Config:
+ server_name = "synmark-" + str(loops)
+ no_redirect_stdio = True
+
+ hs_config = Config()
+
+ # To be able to sleep.
+ clock = Clock(reactor)
errors = StringIO()
publisher = LogPublisher()
@@ -72,47 +82,49 @@ async def main(reactor, loops):
)
log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
+ "version": 1,
+ "loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}},
+ "formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}},
+ "handlers": {
"tersejson": {
- "type": "network_json_terse",
+ "class": "synapse.logging.RemoteHandler",
"host": "127.0.0.1",
"port": port.getHost().port,
"maximum_buffer": 100,
+ "_reactor": reactor,
}
},
}
- logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher)
- logging_system = setup_structured_logging(
- hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False
+ logger = logging.getLogger("synapse.logging.test_terse_json")
+ _setup_stdlib_logging(
+ hs_config, log_config, logBeginner=beginner,
)
# Wait for it to connect...
- await logging_system._observers[0]._service.whenConnected()
+ for handler in logging.getLogger("synapse").handlers:
+ if isinstance(handler, RemoteHandler):
+ break
+ else:
+ raise RuntimeError("Improperly configured: no RemoteHandler found.")
+
+ await handler._service.whenConnected()
start = perf_counter()
# Send a bunch of useful messages
for i in range(0, loops):
- logger.info("test message %s" % (i,))
-
- if (
- len(logging_system._observers[0]._buffer)
- == logging_system._observers[0].maximum_buffer
- ):
- while (
- len(logging_system._observers[0]._buffer)
- > logging_system._observers[0].maximum_buffer / 2
- ):
- await wait(0.01)
+ logger.info("test message %s", i)
+
+ if len(handler._buffer) == handler.maximum_buffer:
+ while len(handler._buffer) > handler.maximum_buffer / 2:
+ await clock.sleep(0.01)
await logger_factory.on_done
end = perf_counter() - start
- logging_system.stop()
+ handler.close()
port.stopListening()
- cleanup()
return end
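For comparison, the shape of the new stdlib-style logging configuration can be tried with plain `logging.config.dictConfig`; this sketch swaps in a `StreamHandler` and a standard formatter instead of Synapse's `RemoteHandler`/`TerseJsonFormatter` so that it runs on its own:

    import logging
    import logging.config

    log_config = {
        "version": 1,
        "formatters": {
            "terse": {"format": "%(asctime)s %(name)s %(levelname)s %(message)s"},
        },
        "handlers": {
            "console": {"class": "logging.StreamHandler", "formatter": "terse"},
        },
        "loggers": {"synapse": {"level": "DEBUG", "handlers": ["console"]}},
    }

    logging.config.dictConfig(log_config)
    logging.getLogger("synapse").info("test message %s", 1)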
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cb6f29d670..0fd55f428a 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
@@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
+ user_info = TokenLookupResult(
+ user_id=self.test_user, token_id=5, device_id="device"
+ )
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto"}
+ user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
- {"name": "@baldrick:matrix.org", "device_id": "device"}
+ TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
@@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
- user = user_info["user"]
- self.assertEqual(UserID.from_string(user_id), user)
+ self.assertEqual(user_id, user_info.user_id)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
- self.assertEqual(user_info["device_id"], "device")
+ self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
@@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
- user = user_info["user"]
- is_guest = user_info["is_guest"]
- self.assertEqual(UserID.from_string(user_id), user)
- self.assertTrue(is_guest)
+ self.assertEqual(user_id, user_info.user_id)
+ self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
@@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok:
return defer.succeed(None)
return defer.succeed(
- {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ TokenLookupResult(
+ user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
+ )
)
self.store.get_user_by_access_token = get_user
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=True,
+ None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=False,
+ None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
+ sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 4512c51311..875aaec2c6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
- self.assertEqual(user_info["device_id"], retrieved_device_id)
+ self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9f6f21a6e2..2e0fea04af 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
)
- self.token_id = self.info["token_id"]
+ self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index e69de29bb2..a58d51441c 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+
+class LoggerCleanupMixin:
+ def get_logger(self, handler):
+ """
+ Attach a handler to a logger and add clean-ups to revert this.
+ """
+ # Create a logger and add the handler to it.
+ logger = logging.getLogger(__name__)
+ logger.addHandler(handler)
+
+ # Ensure the logger actually logs something.
+ logger.setLevel(logging.INFO)
+
+ # Ensure the logger gets cleaned-up appropriately.
+ self.addCleanup(logger.removeHandler, handler)
+ self.addCleanup(logger.setLevel, logging.NOTSET)
+
+ return logger
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
new file mode 100644
index 0000000000..4bc27a1d7d
--- /dev/null
+++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import AccumulatingProtocol
+
+from synapse.logging import RemoteHandler
+
+from tests.logging import LoggerCleanupMixin
+from tests.server import FakeTransport, get_clock
+from tests.unittest import TestCase
+
+
+def connect_logging_client(reactor, client_id):
+ # This is essentially tests.server.connect_client, but disabling autoflush on
+ # the client transport. This is necessary to avoid an infinite loop due to
+ # sending of data via the logging transport causing additional logs to be
+ # written.
+ factory = reactor.tcpClients.pop(client_id)[2]
+ client = factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, reactor))
+ client.makeConnection(FakeTransport(server, reactor, autoflush=False))
+
+ return client, server
+
+
+class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.reactor, _ = get_clock()
+
+ def test_log_output(self):
+ """
+ The remote handler delivers logs over TCP.
+ """
+ handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
+ logger = self.get_logger(handler)
+
+ logger.info("Hello there, %s!", "wally")
+
+ # Trigger the connection
+ client, server = connect_logging_client(self.reactor, 0)
+
+ # Trigger data being sent
+ client.transport.flush()
+
+ # One log message, with a single trailing newline
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(server.data.count(b"\n"), 1)
+
+ # Ensure the data passed through properly.
+ self.assertEqual(logs[0], "Hello there, wally!")
+
+ def test_log_backpressure_debug(self):
+ """
+ When backpressure is hit, DEBUG logs will be shed.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 7):
+ logger.info("info %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # Only the 7 infos made it through, the debugs were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 7)
+ self.assertNotIn(b"debug", server.data)
+
+ def test_log_backpressure_info(self):
+ """
+ When backpressure is hit, DEBUG and INFO logs will be shed.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 10):
+ logger.warning("warn %s" % (i,))
+
+ # Send a bunch of info messages
+ for i in range(0, 3):
+ logger.info("info %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # The 10 warnings made it through, the debugs and infos were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 10)
+ self.assertNotIn(b"debug", server.data)
+ self.assertNotIn(b"info", server.data)
+
+ def test_log_backpressure_cut_middle(self):
+ """
+ When backpressure is hit and no more DEBUG or INFO messages can be culled,
+ the middle messages are cut out.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send a bunch of useful messages
+ for i in range(0, 20):
+ logger.warning("warn %s" % (i,))
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # The first five and last five warnings made it through; the middle
+ # messages were cut out.
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(
+ ["warn %s" % (i,) for i in range(5)]
+ + ["warn %s" % (i,) for i in range(15, 20)],
+ logs,
+ )
+
+ def test_cancel_connection(self):
+ """
+ Gracefully handle the connection being cancelled.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send a message.
+ logger.info("Hello there, %s!", "wally")
+
+ # Do not accept the connection and shutdown. This causes the pending
+ # connection to be cancelled (and should not raise any exceptions).
+ handler.close()
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
deleted file mode 100644
index d36f5f426c..0000000000
--- a/tests/logging/test_structured.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import os.path
-import shutil
-import sys
-import textwrap
-
-from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile
-
-from synapse.config.logger import setup_logging
-from synapse.logging._structured import setup_structured_logging
-from synapse.logging.context import LoggingContext
-
-from tests.unittest import DEBUG, HomeserverTestCase
-
-
-class FakeBeginner:
- def beginLoggingTo(self, observers, **kwargs):
- self.observers = observers
-
-
-class StructuredLoggingTestBase:
- """
- Test base that registers a cleanup handler to reset the stdlib log handler
- to 'unset'.
- """
-
- def prepare(self, reactor, clock, hs):
- def _cleanup():
- logging.getLogger("synapse").setLevel(logging.NOTSET)
-
- self.addCleanup(_cleanup)
-
-
-class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
- """
- Tests for Synapse's structured logging support.
- """
-
- def test_output_to_json_round_trip(self):
- """
- Synapse logs can be outputted to JSON and then read back again.
- """
- temp_dir = self.mktemp()
- os.mkdir(temp_dir)
- self.addCleanup(shutil.rmtree, temp_dir)
-
- json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))
-
- log_config = {
- "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Read the log file and check it has the event we sent
- with open(json_log_file, "r") as f:
- logged_events = list(eventsFromJSONLogFile(f))
- self.assertEqual(len(logged_events), 1)
-
- # The event pulled from the file should render fine
- self.assertEqual(
- eventAsText(logged_events[0], includeTimestamp=False),
- "[tests.logging.test_structured#info] Hello there, wally!",
- )
-
- def test_output_to_text(self):
- """
- Synapse logs can be outputted to text.
- """
- temp_dir = self.mktemp()
- os.mkdir(temp_dir)
- self.addCleanup(shutil.rmtree, temp_dir)
-
- log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))
-
- log_config = {"drains": {"file": {"type": "file", "location": log_file}}}
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Read the log file and check it has the event we sent
- with open(log_file, "r") as f:
- logged_events = f.read().strip().split("\n")
- self.assertEqual(len(logged_events), 1)
-
- # The event pulled from the file should render fine
- self.assertTrue(
- logged_events[0].endswith(
- " - tests.logging.test_structured - INFO - None - Hello there, wally!"
- )
- )
-
- def test_collects_logcontext(self):
- """
- Test that log outputs have the attached logging context.
- """
- log_config = {"drains": {}}
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- publisher = setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- logs = []
-
- publisher.addObserver(logs.append)
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
-
- with LoggingContext("testcontext", request="somereq"):
- logger.info("Hello there, {name}!", name="steve")
-
- self.assertEqual(len(logs), 1)
- self.assertEqual(logs[0]["request"], "somereq")
-
-
-class StructuredLoggingConfigurationFileTestCase(
- StructuredLoggingTestBase, HomeserverTestCase
-):
- def make_homeserver(self, reactor, clock):
-
- tempdir = self.mktemp()
- os.mkdir(tempdir)
- log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
- self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))
-
- config = self.default_config()
- config["log_config"] = log_config_file
-
- with open(log_config_file, "w") as f:
- f.write(
- textwrap.dedent(
- """\
- structured: true
-
- drains:
- file:
- type: file_json
- location: %s
- """
- % (self.homeserver_log,)
- )
- )
-
- self.addCleanup(self._sys_cleanup)
-
- return self.setup_test_homeserver(config=config)
-
- def _sys_cleanup(self):
- sys.stdout = sys.__stdout__
- sys.stderr = sys.__stderr__
-
- # Do not remove! We need the logging system to be set other than WARNING.
- @DEBUG
- def test_log_output(self):
- """
- When a structured logging config is given, Synapse will use it.
- """
- beginner = FakeBeginner()
- publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)
-
- # Make a logger and send an event
- logger = Logger(namespace="tests.logging.test_structured", observer=publisher)
-
- with LoggingContext("testcontext", request="somereq"):
- logger.info("Hello there, {name}!", name="steve")
-
- with open(self.homeserver_log, "r") as f:
- logged_events = [
- eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
- ]
-
- logs = "\n".join(logged_events)
- self.assertTrue("***** STARTING SERVER *****" in logs)
- self.assertTrue("Hello there, steve!" in logs)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index fd128b88e0..73f469b802 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -14,57 +14,33 @@
# limitations under the License.
import json
-from collections import Counter
+import logging
+from io import StringIO
-from twisted.logger import Logger
+from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
-from synapse.logging._structured import setup_structured_logging
+from tests.logging import LoggerCleanupMixin
+from tests.unittest import TestCase
-from tests.server import connect_client
-from tests.unittest import HomeserverTestCase
-from .test_structured import FakeBeginner, StructuredLoggingTestBase
-
-
-class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
- def test_log_output(self):
+class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+ def test_terse_json_output(self):
"""
- The Terse JSON outputter delivers simplified structured logs over TCP.
+ The Terse JSON formatter converts log messages to JSON.
"""
- log_config = {
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- }
- }
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- logger = Logger(
- namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Trigger the connection
- self.pump()
+ output = StringIO()
- _, server = connect_client(self.reactor, 0)
+ handler = logging.StreamHandler(output)
+ handler.setFormatter(TerseJsonFormatter())
+ logger = self.get_logger(handler)
- # Trigger data being sent
- self.pump()
+ logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline
- logs = server.data.decode("utf8").splitlines()
+ # One log message, with a single trailing newline.
+ data = output.getvalue()
+ logs = data.splitlines()
self.assertEqual(len(logs), 1)
- self.assertEqual(server.data.count(b"\n"), 1)
-
+ self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
# The terse logger should give us these keys.
@@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"log",
"time",
"level",
- "log_namespace",
- "request",
- "scope",
- "server_name",
- "name",
+ "namespace",
]
self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
- # It contains the data we expect.
- self.assertEqual(log["name"], "wally")
-
- def test_log_backpressure_debug(self):
+ def test_extra_data(self):
"""
- When backpressure is hit, DEBUG logs will be shed.
+ Additional information can be included in the structured logging.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
-
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
- )
+ output = StringIO()
- # Send some debug messages
- for i in range(0, 3):
- logger.debug("debug %s" % (i,))
+ handler = logging.StreamHandler(output)
+ handler.setFormatter(TerseJsonFormatter())
+ logger = self.get_logger(handler)
- # Send a bunch of useful messages
- for i in range(0, 7):
- logger.info("test message %s" % (i,))
-
- # The last debug message pushes it past the maximum buffer
- logger.debug("too much debug")
-
- # Allow the reconnection
- _, server = connect_client(self.reactor, 0)
- self.pump()
-
- # Only the 7 infos made it through, the debugs were elided
- logs = server.data.splitlines()
- self.assertEqual(len(logs), 7)
-
- def test_log_backpressure_info(self):
- """
- When backpressure is hit, DEBUG and INFO logs will be shed.
- """
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
-
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
+ logger.info(
+ "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
- # Send some debug messages
- for i in range(0, 3):
- logger.debug("debug %s" % (i,))
-
- # Send a bunch of useful messages
- for i in range(0, 10):
- logger.warn("test warn %s" % (i,))
-
- # Send a bunch of info messages
- for i in range(0, 3):
- logger.info("test message %s" % (i,))
-
- # The last debug message pushes it past the maximum buffer
- logger.debug("too much debug")
-
- # Allow the reconnection
- client, server = connect_client(self.reactor, 0)
- self.pump()
+ # One log message, with a single trailing newline.
+ data = output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ log = json.loads(logs[0])
- # The 10 warnings made it through, the debugs and infos were elided
- logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
- self.assertEqual(len(logs), 10)
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "time",
+ "level",
+ "namespace",
+ # The additional keys given via extra.
+ "foo",
+ "int",
+ "bool",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
- self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
+ # Check the values of the extra fields.
+ self.assertEqual(log["foo"], "bar")
+ self.assertEqual(log["int"], 3)
+ self.assertIs(log["bool"], True)
- def test_log_backpressure_cut_middle(self):
+ def test_json_output(self):
"""
- When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
- it will cut the middle messages out.
+ The JSON formatter converts log messages to JSON.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
+ output = StringIO()
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
- )
+ handler = logging.StreamHandler(output)
+ handler.setFormatter(JsonFormatter())
+ logger = self.get_logger(handler)
- # Send a bunch of useful messages
- for i in range(0, 20):
- logger.warn("test warn", num=i)
+ logger.info("Hello there, %s!", "wally")
- # Allow the reconnection
- client, server = connect_client(self.reactor, 0)
- self.pump()
+ # One log message, with a single trailing newline.
+ data = output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ log = json.loads(logs[0])
- # The first five and last five warnings made it through, the debugs and
- # infos were elided
- logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
- self.assertEqual(len(logs), 10)
- self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
- self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index d9993e6245..bcdcafa5a9 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..8571924b29 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 093e2faac7..5c633ac6df 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,6 @@ import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
-import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@@ -39,12 +38,22 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport, render
+try:
+ import hiredis
+except ImportError:
+ hiredis = None
+
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ # hiredis is an optional dependency so we don't want to require it for running
+ # the tests.
+ if not hiredis:
+ skip = "Requires hiredis"
+
servlets = [
streams.register_servlets,
]
@@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
- **kwargs
+ **kwargs,
)
# If the instance is in the `instance_map` config then workers may try
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
sender=sender,
type="test_event",
content={"body": body},
- **kwargs
+ **kwargs,
)
)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
new file mode 100644
index 0000000000..77c261dbf7
--- /dev/null
+++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,277 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from binascii import unhexlify
+from typing import Tuple
+
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+from twisted.web.server import Request
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.server import HomeServer
+
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, FakeTransport
+
+logger = logging.getLogger(__name__)
+
+test_server_connection_factory = None
+
+
+class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks running multiple media repos work correctly.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ self.reactor.lookups["example.com"] = "127.0.0.2"
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ return conf
+
+ def _get_media_req(
+ self, hs: HomeServer, target: str, media_id: str
+ ) -> Tuple[FakeChannel, Request]:
+ """Request some remote media from the given HS by calling the download
+ API.
+
+ This then triggers an outbound request from the HS to the target.
+
+ Returns:
+ The channel for the *client* request and the *outbound* request for
+ the media which the caller should respond to.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/{}/{}".format(target, media_id),
+ shorthand=False,
+ access_token=self.access_token,
+ )
+ request.render(hs.get_media_repository_resource().children[b"download"])
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+ # build the test server
+ server_tls_protocol = _build_test_server(get_connection_factory())
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_server = server_tls_protocol.wrappedProtocol
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(
+ request.path,
+ "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+ )
+
+ return channel, request
+
+ def test_basic(self):
+ """Test basic fetching of remote media from a single worker.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+ channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+ request.setResponseCode(200)
+ request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request.write(b"Hello!")
+ request.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], b"Hello!")
+
+ def test_download_simple_file_race(self):
+ """Test that fetching remote media from two different processes at the
+ same time works.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_media()
+
+ # Make two requests without responding to the outbound media requests.
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+ # Respond to the first outbound media request and check that the client
+ # request is successful
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request1.write(b"Hello!")
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], b"Hello!")
+
+ # Now respond to the second with the same content.
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request2.write(b"Hello!")
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], b"Hello!")
+
+ # We expect only one new file to have been persisted.
+ self.assertEqual(start_count + 1, self._count_remote_media())
+
+ def test_download_image_race(self):
+ """Test that fetching remote *images* from two different processes at
+ the same time works.
+
+ This checks that races generating thumbnails are handled correctly.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_thumbnails()
+
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
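+ # A minimal 1x1 PNG to use as the remote image content.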
+ png_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request1.write(png_data)
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], png_data)
+
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request2.write(png_data)
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], png_data)
+
+ # We expect only three new thumbnails to have been persisted.
+ self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+ def _count_remote_media(self) -> int:
+ """Count the number of files in our remote media directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_content"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+ def _count_remote_thumbnails(self) -> int:
+ """Count the number of files in our remote thumbnails directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+
+def get_connection_factory():
+ # this needs to happen once, but not until we are ready to run the first test
+ global test_server_connection_factory
+ if test_server_connection_factory is None:
+ test_server_connection_factory = TestServerTLSConnectionFactory(
+ sanlist=[b"DNS:example.com"]
+ )
+ return test_server_connection_factory
+
+
+def _build_test_server(connection_creator):
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Args:
+ connection_creator (IOpenSSLServerConnectionCreator): thing to build
+ SSL connections
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=server_factory
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..67c27a089f 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_dict["token_id"]
+ token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e815b92329..7df32e5093 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1118,6 +1118,130 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
+class PushersRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list pushers of a user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local user returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_pushers(self):
+ """
+ Tests that a normal lookup for pushers is successful
+ """
+
+ # Get pushers
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ # Register the pusher
+ other_user_token = self.login("user", "pass")
+ user_tuple = self.get_success(
+ self.store.get_user_by_access_token(other_user_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=self.other_user,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Get pushers
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+
+ for p in channel.json_body["pushers"]:
+ self.assertIn("pushkey", p)
+ self.assertIn("kind", p)
+ self.assertIn("app_id", p)
+ self.assertIn("app_display_name", p)
+ self.assertIn("device_display_name", p)
+ self.assertIn("profile_tag", p)
+ self.assertIn("lang", p)
+ self.assertIn("url", p["data"])
+
+
class UserMediaRestTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..98c3887bbf 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
)
self.hs.get_datastore().services_cache.append(appservice)
diff --git a/tests/server.py b/tests/server.py
index 4d33b84097..3dd2cfc072 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -46,7 +46,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
- result = attr.ib(default=attr.Factory(dict))
+ result = attr.ib(type=dict, default=attr.Factory(dict))
_producer = None
@property
@@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runWithConnection,
func,
*args,
- **kwargs
+ **kwargs,
)
def runInteraction(interaction, *args, **kwargs):
@@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runInteraction,
interaction,
*args,
- **kwargs
+ **kwargs,
)
pool.runWithConnection = runWithConnection
@@ -571,12 +571,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol
reactor
factory: The connecting factory to build.
"""
- factory = reactor.tcpClients[client_id][2]
+ factory = reactor.tcpClients.pop(client_id)[2]
client = factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, reactor))
client.makeConnection(FakeTransport(server, reactor))
- reactor.tcpClients.pop(client_id)
-
return client, server
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..e96ca1c8ca 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
"GET",
"/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token,
- **make_request_args
+ **make_request_args,
)
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1])
)
- self.assertDictContainsSubset(
- {"name": self.user_id, "device_id": self.device_id}, result
- )
-
- self.assertTrue("token_id" in result)
+ self.assertEqual(result.user_id, self.user_id)
+ self.assertEqual(result.device_id, self.device_id)
+ self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
- self.assertEqual(self.user_id, user["name"])
+ self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
"""
Utilities for running the unit tests
"""
+import sys
+import warnings
from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
TV = TypeVar("TV")
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
future = Future() # type: ignore
future.set_result(result)
return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+ """
+ Convert warnings from non-awaited coroutines into errors.
+ """
+ warnings.simplefilter("error", RuntimeWarning)
+
+ # unraisablehook was added in Python 3.8.
+ if not hasattr(sys, "unraisablehook"):
+ return lambda: None
+
+ # State shared between unraisablehook and check_for_unraisable_exceptions.
+ unraisable_exceptions = []
+ orig_unraisablehook = sys.unraisablehook # type: ignore
+
+ def unraisablehook(unraisable):
+ unraisable_exceptions.append(unraisable.exc_value)
+
+ def cleanup():
+ """
+ A cleanup function which fails the test case if there are any new unraisable exceptions.
+ """
+ sys.unraisablehook = orig_unraisablehook # type: ignore
+ if unraisable_exceptions:
+ raise unraisable_exceptions.pop()
+
+ sys.unraisablehook = unraisablehook # type: ignore
+
+ return cleanup
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index e93aa84405..c3c4a93e1f 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -50,7 +50,7 @@ async def inject_member_event(
sender=sender,
state_key=target,
content=content,
- **kwargs
+ **kwargs,
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 257f465897..08cf9b10c5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -54,7 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
+ # Trial messes with the warnings configuration, thus this has to be
+ # done in the context of an individual TestCase.
+ self.addCleanup(setup_awaitable_errors())
+
return orig()
@around(self)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
from tests import unittest
+from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
+class LruCacheDecoratorTestCase(unittest.TestCase):
+ def test_base(self):
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @lru_cache()
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
+
+
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
@@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_invalidate_cascade(self):
+ """Invalidations should cascade up through cache contexts"""
+
+ class Cls:
+ @cached(cache_context=True)
+ async def func1(self, key, cache_context):
+ return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+ @cached(cache_context=True)
+ async def func2(self, key, cache_context):
+ return self.func3(key, on_invalidate=cache_context.invalidate)
+
+ @lru_cache(cache_context=True)
+ def func3(self, key, cache_context):
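+ # Remember the innermost cache's invalidation callback so the test can trigger it.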
+ self.invalidate = cache_context.invalidate
+ return 42
+
+ obj = Cls()
+
+ top_invalidate = mock.Mock()
+ r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+ self.assertEqual(r, 42)
+ obj.invalidate()
+ top_invalidate.assert_called_once()
+
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
diff --git a/tox.ini b/tox.ini
index 6dcc439a40..c232676826 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort
+envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort
[base]
extras = test
@@ -24,11 +24,6 @@ deps =
pip>=10
setenv =
- # we have a pyproject.toml, but don't want pip to use it for building.
- # (otherwise we get an error about 'editable mode is not supported for
- # pyproject.toml-style projects').
- PIP_USE_PEP517 = false
-
PYTHONDONTWRITEBYTECODE = no_byte_code
COVERAGE_PROCESS_START = {toxinidir}/.coveragerc
|