diff options
193 files changed, 3400 insertions, 1631 deletions
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1ead0d0030..8939fda67d 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,3 +5,4 @@ * [ ] Pull request is based on the develop branch * [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog) * [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off) +* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style)) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index a71a4a696b..df81f6e54f 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -58,10 +58,29 @@ All Matrix projects have a well-defined code-style - and sometimes we've even got as far as documenting it... For instance, synapse's code style doc lives at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md. +To facilitate meeting these criteria you can run ``scripts-dev/lint.sh`` +locally. Since this runs the tools listed in the above document, you'll need +python 3.6 and to install each tool. **Note that the script does not just +test/check, but also reformats code, so you may wish to ensure any new code is +committed first**. By default this script checks all files and can take some +time; if you alter only certain files, you might wish to specify paths as +arguments to reduce the run-time. + Please ensure your changes match the cosmetic style of the existing project, and **never** mix cosmetic and functional changes in the same commit, as it makes it horribly hard to review otherwise. +Before doing a commit, ensure the changes you've made don't produce +linting errors. You can do this by running the linters as follows. Ensure to +commit any files that were corrected. + +:: + # Install the dependencies + pip install -U black flake8 isort + + # Run the linter script + ./scripts-dev/lint.sh + Changelog ~~~~~~~~~ diff --git a/INSTALL.md b/INSTALL.md index 69e423923b..e7b429c05d 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -413,16 +413,18 @@ For a more detailed guide to configuring your server for federation, see ## Email -It is desirable for Synapse to have the capability to send email. For example, -this is required to support the 'password reset' feature. +It is desirable for Synapse to have the capability to send email. This allows +Synapse to send password reset emails, send verifications when an email address +is added to a user's account, and send email notifications to users when they +receive new messages. To configure an SMTP server for Synapse, modify the configuration section -headed ``email``, and be sure to have at least the ``smtp_host``, ``smtp_port`` -and ``notif_from`` fields filled out. You may also need to set ``smtp_user``, -``smtp_pass``, and ``require_transport_security``. +headed `email`, and be sure to have at least the `smtp_host`, `smtp_port` +and `notif_from` fields filled out. You may also need to set `smtp_user`, +`smtp_pass`, and `require_transport_security`. -If Synapse is not configured with an SMTP server, password reset via email will - be disabled by default. +If email is not configured, password reset, registration and notifications via +email will be disabled. ## Registering a user diff --git a/changelog.d/5727.feature b/changelog.d/5727.feature new file mode 100644 index 0000000000..819bebf2d7 --- /dev/null +++ b/changelog.d/5727.feature @@ -0,0 +1 @@ +Add federation support for cross-signing. diff --git a/changelog.d/6164.doc b/changelog.d/6164.doc new file mode 100644 index 0000000000..f9395b02b3 --- /dev/null +++ b/changelog.d/6164.doc @@ -0,0 +1 @@ +Contributor documentation now mentions script to run linters. diff --git a/changelog.d/6232.bugfix b/changelog.d/6232.bugfix new file mode 100644 index 0000000000..12718ba934 --- /dev/null +++ b/changelog.d/6232.bugfix @@ -0,0 +1 @@ +Remove a room from a server's public rooms list on room upgrade. \ No newline at end of file diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature new file mode 100644 index 0000000000..d225ac33b6 --- /dev/null +++ b/changelog.d/6238.feature @@ -0,0 +1 @@ +Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars. diff --git a/changelog.d/6240.misc b/changelog.d/6240.misc new file mode 100644 index 0000000000..0b3d7a14a1 --- /dev/null +++ b/changelog.d/6240.misc @@ -0,0 +1 @@ +Move `persist_events` out from main data store. diff --git a/changelog.d/6254.bugfix b/changelog.d/6254.bugfix new file mode 100644 index 0000000000..3181484b88 --- /dev/null +++ b/changelog.d/6254.bugfix @@ -0,0 +1 @@ +Make notification of cross-signing signatures work with workers. diff --git a/changelog.d/6257.doc b/changelog.d/6257.doc new file mode 100644 index 0000000000..e985afde0e --- /dev/null +++ b/changelog.d/6257.doc @@ -0,0 +1 @@ +Modify CAPTCHA_SETUP.md to update the terms `private key` and `public key` to `secret key` and `site key` respectively. Contributed by Yash Jipkate. diff --git a/changelog.d/6259.misc b/changelog.d/6259.misc new file mode 100644 index 0000000000..3ff81b1ac7 --- /dev/null +++ b/changelog.d/6259.misc @@ -0,0 +1 @@ +Expose some homeserver functionality to spam checkers. diff --git a/changelog.d/6269.misc b/changelog.d/6269.misc new file mode 100644 index 0000000000..9fd333cc89 --- /dev/null +++ b/changelog.d/6269.misc @@ -0,0 +1 @@ +Fix incorrect comment regarding the functionality of an `if` statement. \ No newline at end of file diff --git a/changelog.d/6270.misc b/changelog.d/6270.misc new file mode 100644 index 0000000000..d1c5811323 --- /dev/null +++ b/changelog.d/6270.misc @@ -0,0 +1 @@ +Update CI to run `isort` over the `scripts` and `scripts-dev` directories. \ No newline at end of file diff --git a/changelog.d/6271.misc b/changelog.d/6271.misc new file mode 100644 index 0000000000..2369760272 --- /dev/null +++ b/changelog.d/6271.misc @@ -0,0 +1 @@ +Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. \ No newline at end of file diff --git a/changelog.d/6272.doc b/changelog.d/6272.doc new file mode 100644 index 0000000000..232180bcdc --- /dev/null +++ b/changelog.d/6272.doc @@ -0,0 +1 @@ +Update `INSTALL.md` Email section to talk about `account_threepid_delegates`. \ No newline at end of file diff --git a/changelog.d/6273.doc b/changelog.d/6273.doc new file mode 100644 index 0000000000..21a41d987d --- /dev/null +++ b/changelog.d/6273.doc @@ -0,0 +1 @@ +Fix a small typo in `account_threepid_delegates` configuration option. \ No newline at end of file diff --git a/changelog.d/6274.misc b/changelog.d/6274.misc new file mode 100644 index 0000000000..eb4966124f --- /dev/null +++ b/changelog.d/6274.misc @@ -0,0 +1 @@ +Port replication http server endpoints to async/await. diff --git a/changelog.d/6275.misc b/changelog.d/6275.misc new file mode 100644 index 0000000000..f57e2c4adb --- /dev/null +++ b/changelog.d/6275.misc @@ -0,0 +1 @@ +Port room rest handlers to async/await. diff --git a/changelog.d/6277.misc b/changelog.d/6277.misc new file mode 100644 index 0000000000..490713577f --- /dev/null +++ b/changelog.d/6277.misc @@ -0,0 +1 @@ +Remove redundant CLI parameters on CI's `flake8` step. \ No newline at end of file diff --git a/changelog.d/6278.bugfix b/changelog.d/6278.bugfix new file mode 100644 index 0000000000..c107270461 --- /dev/null +++ b/changelog.d/6278.bugfix @@ -0,0 +1 @@ +Fix exception when remote servers attempt to join a room that they're not allowed to join. diff --git a/changelog.d/6279.misc b/changelog.d/6279.misc new file mode 100644 index 0000000000..5f5144a9ee --- /dev/null +++ b/changelog.d/6279.misc @@ -0,0 +1 @@ +Port `federation_server.py` to async/await. diff --git a/changelog.d/6280.misc b/changelog.d/6280.misc new file mode 100644 index 0000000000..96a0eb21b2 --- /dev/null +++ b/changelog.d/6280.misc @@ -0,0 +1 @@ +Port receipt and read markers to async/wait. diff --git a/changelog.d/6284.bugfix b/changelog.d/6284.bugfix new file mode 100644 index 0000000000..cf15053d2d --- /dev/null +++ b/changelog.d/6284.bugfix @@ -0,0 +1 @@ +Prevent errors from appearing on Synapse startup if `git` is not installed. \ No newline at end of file diff --git a/changelog.d/6291.misc b/changelog.d/6291.misc new file mode 100644 index 0000000000..7b1bb4b679 --- /dev/null +++ b/changelog.d/6291.misc @@ -0,0 +1 @@ +Change cache descriptors to always return deferreds. diff --git a/changelog.d/6294.misc b/changelog.d/6294.misc new file mode 100644 index 0000000000..a3e6b8296e --- /dev/null +++ b/changelog.d/6294.misc @@ -0,0 +1 @@ +Split out state storage into separate data store. diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc new file mode 100644 index 0000000000..d4190730b2 --- /dev/null +++ b/changelog.d/6298.misc @@ -0,0 +1 @@ +Refactor EventContext for clarity. \ No newline at end of file diff --git a/changelog.d/6300.misc b/changelog.d/6300.misc new file mode 100644 index 0000000000..0b3d7a14a1 --- /dev/null +++ b/changelog.d/6300.misc @@ -0,0 +1 @@ +Move `persist_events` out from main data store. diff --git a/changelog.d/6304.misc b/changelog.d/6304.misc new file mode 100644 index 0000000000..20372b4f7c --- /dev/null +++ b/changelog.d/6304.misc @@ -0,0 +1 @@ +Update the version of black used to 19.10b0. diff --git a/changelog.d/6306.bugfix b/changelog.d/6306.bugfix new file mode 100644 index 0000000000..c7dcbcdce8 --- /dev/null +++ b/changelog.d/6306.bugfix @@ -0,0 +1 @@ +Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash. diff --git a/changelog.d/6307.bugfix b/changelog.d/6307.bugfix new file mode 100644 index 0000000000..f2917c5053 --- /dev/null +++ b/changelog.d/6307.bugfix @@ -0,0 +1 @@ +Fix `/purge_room` admin API. diff --git a/changelog.d/6312.misc b/changelog.d/6312.misc new file mode 100644 index 0000000000..55e3e1654d --- /dev/null +++ b/changelog.d/6312.misc @@ -0,0 +1 @@ +Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only. diff --git a/changelog.d/6313.bugfix b/changelog.d/6313.bugfix new file mode 100644 index 0000000000..f4d4a97f00 --- /dev/null +++ b/changelog.d/6313.bugfix @@ -0,0 +1 @@ +Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0. diff --git a/changelog.d/6314.misc b/changelog.d/6314.misc new file mode 100644 index 0000000000..2369760272 --- /dev/null +++ b/changelog.d/6314.misc @@ -0,0 +1 @@ +Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. \ No newline at end of file diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 6b22400a60..3bbbcfa1b4 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -78,7 +78,7 @@ class InputOutput(object): m = re.match("^join (\S+)$", line) if m: # The `sender` wants to join a room. - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("%s joining %s" % (self.user, room_name)) self.server.join_room(room_name, self.user, self.user) # self.print_line("OK.") @@ -105,7 +105,7 @@ class InputOutput(object): m = re.match("^backfill (\S+)$", line) if m: # we want to backfill a room - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("backfill %s" % room_name) self.server.backfill(room_name) return diff --git a/docker/start.py b/docker/start.py index e41ea20e70..6e1cb807a1 100755 --- a/docker/start.py +++ b/docker/start.py @@ -217,8 +217,9 @@ def main(args, environ): # backwards-compatibility generate-a-config-on-the-fly mode if "SYNAPSE_CONFIG_PATH" in environ: error( - "SYNAPSE_SERVER_NAME and SYNAPSE_CONFIG_PATH are mutually exclusive " - "except in `generate` or `migrate_config` mode." + "SYNAPSE_SERVER_NAME can only be combined with SYNAPSE_CONFIG_PATH " + "in `generate` or `migrate_config` mode. To start synapse using a " + "config file, unset the SYNAPSE_SERVER_NAME environment variable." ) config_path = "/compiled/homeserver.yaml" diff --git a/docs/CAPTCHA_SETUP.md b/docs/CAPTCHA_SETUP.md index 5f9057530b..331e5d059a 100644 --- a/docs/CAPTCHA_SETUP.md +++ b/docs/CAPTCHA_SETUP.md @@ -4,7 +4,7 @@ The captcha mechanism used is Google's ReCaptcha. This requires API keys from Go ## Getting keys -Requires a public/private key pair from: +Requires a site/secret key pair from: <https://developers.google.com/recaptcha/> @@ -15,8 +15,8 @@ Must be a reCAPTCHA v2 key using the "I'm not a robot" Checkbox option The keys are a config option on the home server config. If they are not visible, you can generate them via `--generate-config`. Set the following value: - recaptcha_public_key: YOUR_PUBLIC_KEY - recaptcha_private_key: YOUR_PRIVATE_KEY + recaptcha_public_key: YOUR_SITE_KEY + recaptcha_private_key: YOUR_SECRET_KEY In addition, you MUST enable captchas via: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6c81c0db75..d2f4aff826 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -955,7 +955,7 @@ uploads_path: "DATADIR/uploads" # If a delegate is specified, the config option public_baseurl must also be filled out. # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process # Users who register on this homeserver will automatically be joined diff --git a/mypy.ini b/mypy.ini index ffadaddc0b..1d77c0ecc8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,8 +1,11 @@ [mypy] -namespace_packages=True -plugins=mypy_zope:plugin -follow_imports=skip -mypy_path=stubs +namespace_packages = True +plugins = mypy_zope:plugin +follow_imports = normal +check_untyped_defs = True +show_error_codes = True +show_traceback = True +mypy_path = stubs [mypy-zope] ignore_missing_imports = True diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 02a2ca39e5..34c4854e11 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -7,7 +7,15 @@ set -e -isort -y -rc synapse tests scripts-dev scripts -flake8 synapse tests -python3 -m black synapse tests scripts-dev scripts +if [ $# -ge 1 ] +then + files=$* +else + files="synapse tests scripts-dev scripts" +fi + +echo "Linting these locations: $files" +isort -y -rc $files +flake8 $files +python3 -m black $files ./scripts-dev/config-lint.sh diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 10166583e1..27a1ad1e7e 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -25,8 +25,8 @@ from twisted.internet import defer, reactor from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.storage import DataStore +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -122,4 +122,3 @@ if __name__ == "__main__": )) reactor.run() - diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 12747c6024..b5b63933ab 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths): # check that the original exists original_file = src_paths.remote_media_filepath(origin_server, file_id) if not os.path.exists(original_file): - logger.warn( + logger.warning( "Original for %s/%s (%s) does not exist", origin_server, file_id, diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 54faed1e83..0d3321682c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -157,7 +157,7 @@ class Store( ) except self.database_engine.module.DatabaseError as e: if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) + logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) if i < N: i += 1 conn.rollback() @@ -432,7 +432,7 @@ class Porter(object): for row in rows: d = dict(zip(headers, row)) if "\0" in d['value']: - logger.warn('dropping search row %s', d) + logger.warning('dropping search row %s', d) else: rows_dict.append(d) @@ -647,7 +647,7 @@ class Porter(object): if isinstance(col, bytes): return bytearray(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn( + logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 53f3bb0fa8..5d0b7d2801 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -497,7 +497,7 @@ class Auth(object): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: - logger.warn("Unrecognised appservice access token.") + logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender return defer.succeed(service) diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index d877c77834..a01bac2997 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses): bind_addresses (list): Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: - logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") + logger.warning( + "Failed to listen on 0.0.0.0, continuing because listening on [::]" + ) else: raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 767b87d2db..02b900f382 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -94,7 +94,7 @@ class AppserviceServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -103,7 +103,7 @@ class AppserviceServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dbcc414c42..dadb487d5f 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index f20d810ece..d110599a35 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 1ef027a88c..418c086254 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 04fbb407af..139221ad34 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 9504bfbc70..e647459d0e 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index eb54f56853..8d28076d92 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -282,7 +282,7 @@ class SynapseHomeServer(HomeServer): reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -291,7 +291,7 @@ class SynapseHomeServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) def run_startup_checks(self, db_conn, database_engine): all_users_native = are_all_users_on_domain( @@ -565,11 +565,11 @@ def run(hs): "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) ) try: - yield hs.get_simple_http_client().put_json( + yield hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: - logger.warn("Error reporting stats: %s", e) + logger.warning("Error reporting stats: %s", e) def performance_stats_init(): try: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 6bc7202f33..2c6dd3ef02 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index d84732ee3c..01a5ffc363 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -114,7 +114,7 @@ class PusherServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -123,7 +123,7 @@ class PusherServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 6a7e2fa707..b14da09f47 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index a5d6dc7915..6cb100319f 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 33b3579425..aea3985a5f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -94,7 +94,9 @@ class ApplicationService(object): ip_range_whitelist=None, ): self.token = token - self.url = url + self.url = ( + url.rstrip("/") if isinstance(url, str) else None + ) # url must not end with a slash self.hs_token = hs_token self.sender = sender self.server_name = hostname diff --git a/synapse/config/key.py b/synapse/config/key.py index ec5d430afb..52ff1b2621 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -125,7 +125,7 @@ class KeyConfig(Config): # if neither trusted_key_servers nor perspectives are given, use the default. if "perspectives" not in config and "trusted_key_servers" not in config: - logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) + logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) key_servers = [{"server_name": "matrix.org"}] else: key_servers = config.get("trusted_key_servers", []) @@ -156,7 +156,7 @@ class KeyConfig(Config): if not self.macaroon_secret_key: # Unfortunately, there are people out there that don't have this # set. Lets just be "nice" and derive one from their secret key. - logger.warn("Config is missing macaroon_secret_key") + logger.warning("Config is missing macaroon_secret_key") seed = bytes(self.signing_key[0]) self.macaroon_secret_key = hashlib.sha256(seed).digest() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index be92e33f93..75bb904718 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None): logger = logging.getLogger("") if not log_config: - logger.warn("Reloaded a blank config?") + logger.warning("Reloaded a blank config?") logging.config.dictConfig(log_config) @@ -234,8 +234,8 @@ def setup_logging( # make sure that the first thing we log is a thing we can grep backwards # for - logging.warn("***** STARTING SERVER *****") - logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) + logging.warning("***** STARTING SERVER *****") + logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) return logger diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ab41623b2b..1f6dac69da 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -300,7 +300,7 @@ class RegistrationConfig(Config): # If a delegate is specified, the config option public_baseurl must also be filled out. # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process # Users who register on this homeserver will automatically be joined diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e7b722547b..ec3243b27b 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if auth_events is None: # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) + logger.warning("Trusting event: %s", event.event_id) return if event.type == EventTypes.Create: diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 27cd8a63ff..a269de5482 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -37,9 +37,6 @@ class EventContext: delta_ids (dict[(str, str), str]): Delta from ``prev_group``. (type, state_key) -> event_id. ``None`` for an outlier. - prev_state_events (?): XXX: is this ever set to anything other than - the empty list? - app_service: FIXME _current_state_ids (dict[(str, str), str]|None): @@ -51,36 +48,16 @@ class EventContext: The current state map excluding the current event. None if outlier or we haven't fetched the state from DB yet. (type, state_key) -> event_id - - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet - - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. - - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. - - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. """ state_group = attr.ib(default=None) rejected = attr.ib(default=False) prev_group = attr.ib(default=None) delta_ids = attr.ib(default=None) - prev_state_events = attr.ib(default=attr.Factory(list)) app_service = attr.ib(default=None) - _current_state_ids = attr.ib(default=None) _prev_state_ids = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _current_state_ids = attr.ib(default=None) @staticmethod def with_state( @@ -90,7 +67,6 @@ class EventContext: current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, state_group=state_group, - fetching_state_deferred=defer.succeed(None), prev_group=prev_group, delta_ids=delta_ids, ) @@ -125,7 +101,6 @@ class EventContext: "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None, } @@ -141,7 +116,7 @@ class EventContext: Returns: EventContext """ - context = EventContext( + context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. prev_state_id=input["prev_state_id"], @@ -151,7 +126,6 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], - prev_state_events=input["prev_state_events"], ) app_service_id = input["app_service_id"] @@ -170,14 +144,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._current_state_ids @defer.inlineCallbacks @@ -190,14 +157,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._prev_state_ids def get_cached_current_state_ids(self): @@ -211,6 +171,44 @@ class EventContext: return self._current_state_ids + def _ensure_fetched(self, store): + return defer.succeed(None) + + +@attr.s(slots=True) +class _AsyncEventContextImpl(EventContext): + """ + An implementation of EventContext which fetches _current_state_ids and + _prev_state_ids from the database on demand. + + Attributes: + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self, store): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store + ) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids @@ -228,27 +226,6 @@ class EventContext: else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 129771f183..5a907718d6 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + +from synapse.spam_checker_api import SpamCheckerApi + class SpamChecker(object): def __init__(self, hs): @@ -26,7 +31,14 @@ class SpamChecker(object): pass if module is not None: - self.spam_checker = module(config=config) + # Older spam checkers don't accept the `api` argument, so we + # try and detect support. + spam_args = inspect.getfullargspec(module) + if "api" in spam_args.args: + api = SpamCheckerApi(hs) + self.spam_checker = module(config=config, api=api) + else: + self.spam_checker = module(config=config) def check_event_for_spam(self, event): """Checks if a given event is considered "spammy" by this server. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 223aace0d9..0e22183280 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -102,7 +102,7 @@ class FederationBase(object): pass if not res: - logger.warn( + logger.warning( "Failed to find copy of %s with valid signature", pdu.event_id ) @@ -173,7 +173,7 @@ class FederationBase(object): return redacted_event if self.spam_checker.check_event_for_spam(pdu): - logger.warn( + logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json(), @@ -185,7 +185,7 @@ class FederationBase(object): def errback(failure, pdu): failure.trap(SynapseError) with PreserveLoggingContext(ctx): - logger.warn( + logger.warning( "Signature check failed for %s: %s", pdu.event_id, failure.getErrorMessage(), diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f5c1632916..545d719652 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -522,12 +522,12 @@ class FederationClient(FederationBase): res = yield callback(destination) return res except InvalidResponseError as e: - logger.warn("Failed to %s via %s: %s", description, destination, e) + logger.warning("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: - logger.warn( + logger.warning( "Failed to %s via %s: %i %s", description, destination, @@ -535,7 +535,9 @@ class FederationClient(FederationBase): e.args[0], ) except Exception: - logger.warn("Failed to %s via %s", description, destination, exc_info=1) + logger.warning( + "Failed to %s via %s", description, destination, exc_info=1 + ) raise SynapseError(502, "Failed to %s via any server" % (description,)) @@ -553,7 +555,7 @@ class FederationClient(FederationBase): Note that this does not append any events to any graphs. Args: - destinations (str): Candidate homeservers which are probably + destinations (Iterable[str]): Candidate homeservers which are probably participating in the room. room_id (str): The room in which the event will happen. user_id (str): The user whose membership is being evented. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5fc7c1d67b..d942d77a72 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -151,7 +145,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - # Reject if PDU count > 50 and EDU count > 100 + # Reject if PDU count > 50 or EDU count > 100 if len(transaction.pdus) > 50 or ( hasattr(transaction, "edus") and len(transaction.edus) > 100 ): @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,13 +215,12 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn("Ignoring PDUs for room %s from banned server", room_id) + logger.warning("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -237,10 +230,10 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) + logger.warning("Error handling PDU %s: %s", event_id, e) pdu_results[event_id] = {"error": str(e)} except Exception as e: f = failure.Failure() @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,60 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: - logger.warn("Room version %s not in %s", room_version, supported_versions) + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,28 +354,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = yield self._check_sigs_and_hash(room_version, pdu) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -402,48 +386,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - yield self.handler.on_send_leave_request(origin, pdu) + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -462,22 +442,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -503,16 +483,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -536,14 +514,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -553,7 +529,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -586,8 +562,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -640,37 +615,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -680,13 +652,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -709,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event): # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warn("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignorning non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -723,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event): # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warn("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignorning non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -733,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event): # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warn("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignorning non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -747,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn( + logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False @@ -799,15 +771,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: - logger.warn("No handler registered for EDU type %s", edu_type) + logger.warning("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: @@ -816,7 +787,7 @@ class FederationHandlerRegistry(object): def on_query(self, query_type, args): handler = self.query_handlers.get(query_type) if not handler: - logger.warn("No handler registered for query type %s", query_type) + logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) @@ -840,7 +811,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -848,17 +819,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: - return super(ReplicationFederationHandlerRegistry, self).on_edu( + return await super(ReplicationFederationHandlerRegistry, self).on_edu( edu_type, origin, content ) - return self._send_edu(edu_type=edu_type, origin=origin, content=content) + return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - def on_query(self, query_type, args): + async def on_query(self, query_type, args): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) if handler: - return handler(args) + return await handler(args) - return self._get_query_client(query_type=query_type, args=args) + return await self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 454456a52d..ced4925a98 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -36,6 +36,8 @@ from six import iteritems from sortedcontainers import SortedDict +from twisted.internet import defer + from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure @@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object): receipt (synapse.types.ReadReceipt): """ # nothing to do here: the replication listener will handle it. - pass + return defer.succeed(None) def send_presence(self, states): """As per FederationSender diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index cc75c39476..a5b36b1827 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -192,15 +192,16 @@ class PerDestinationQueue(object): # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = ( - yield self._get_device_update_edus(limit) + device_update_edus, dev_list_id = yield self._get_device_update_edus( + limit ) limit -= len(device_update_edus) - to_device_edus, device_stream_id = ( - yield self._get_to_device_message_edus(limit) - ) + ( + to_device_edus, + device_stream_id, + ) = yield self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus @@ -359,20 +360,20 @@ class PerDestinationQueue(object): last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_devices_by_remote( + now_stream_id, results = yield self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type=edu_type, content=content, ) - for content in results + for (edu_type, content) in results ] - assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs" return (edus, now_stream_id) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5b6c79c51a..67b3e1ab6e 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -146,7 +146,7 @@ class TransactionManager(object): if code == 200: for e_id, r in response.get("pdus", {}).items(): if "error" in r: - logger.warn( + logger.warning( "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, @@ -155,7 +155,7 @@ class TransactionManager(object): ) else: for p in pdus: - logger.warn( + logger.warning( "TX [%s] {%s} Failed to send event %s", destination, txn_id, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 0f16f21c2d..d6c23f22bd 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes): sig = strip_quotes(param_dict["sig"]) return origin, key, sig except Exception as e: - logger.warn( + logger.warning( "Error parsing auth header '%s': %s", header_bytes.decode("ascii", "replace"), e, @@ -287,10 +287,12 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.warn("authenticate_request failed: missing authentication") + logger.warning( + "authenticate_request failed: missing authentication" + ) raise except Exception as e: - logger.warn("authenticate_request failed: %s", e) + logger.warning("authenticate_request failed: %s", e) raise request_tags = { diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dfd7ae041b..d950a8b246 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -181,7 +181,7 @@ class GroupAttestionRenewer(object): elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - logger.warn( + logger.warning( "Incorrectly trying to do attestations for user: %r in %r", user_id, group_id, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8f10b6adbb..29e8ffc295 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -488,7 +488,7 @@ class GroupsServerHandler(object): profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) except Exception as e: - logger.warn("Error getting profile for %s: %s", user_id, e) + logger.warning("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 38bc67191c..2d7e6df6e4 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -38,9 +38,10 @@ class AccountDataEventSource(object): {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} ) - account_data, room_account_data = ( - yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) - ) + ( + account_data, + room_account_data, + ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a87b58838..6407d56f8e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,6 +30,9 @@ class AdminHandler(BaseHandler): def __init__(self, hs): super(AdminHandler, self).__init__(hs) + self.storage = hs.get_storage() + self.state_store = self.storage.state + @defer.inlineCallbacks def get_whois(self, user): connections = [] @@ -205,7 +208,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.store, user_id, events) + events = yield filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -241,7 +244,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.store.get_state_for_event(event_id) + state = yield self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3e9b298154..fe62f78e67 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -73,7 +73,10 @@ class ApplicationServicesHandler(object): try: limit = 100 while True: - upper_bound, events = yield self.store.get_new_events_for_appservice( + ( + upper_bound, + events, + ) = yield self.store.get_new_events_for_appservice( self.current_max, limit ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 333eb30625..7a0f54ca24 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -525,7 +525,7 @@ class AuthHandler(BaseHandler): result = None if not user_infos: - logger.warn("Attempted to login as %s but they do not exist", user_id) + logger.warning("Attempted to login as %s but they do not exist", user_id) elif len(user_infos) == 1: # a single match (possibly not exact) result = user_infos.popitem() @@ -534,7 +534,7 @@ class AuthHandler(BaseHandler): result = (user_id, user_infos[user_id]) else: # multiple matches, none of them exact - logger.warn( + logger.warning( "Attempted to login as %s but it matches more than one user " "inexactly: %r", user_id, @@ -728,7 +728,7 @@ class AuthHandler(BaseHandler): result = yield self.validate_hash(password, password_hash) if not result: - logger.warn("Failed password login for user %s", user_id) + logger.warning("Failed password login for user %s", user_id) return None return user_id diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..26ef5e150c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() @trace @@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. @@ -458,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler): @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - return {"user_id": user_id, "stream_id": stream_id, "devices": devices} + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing" + ) + + return { + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + "master_key": master_key, + "self_signing_key": self_signing_key, + } @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -656,7 +668,7 @@ class DeviceListUpdater(object): except (NotRetryingDestination, RequestSendFailed, HttpResponseException): # TODO: Remember that we are now out of sync and try again # later - logger.warn("Failed to handle device list update for %s", user_id) + logger.warning("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -694,7 +706,7 @@ class DeviceListUpdater(object): # up on storing the total list of devices and only handle the # delta instead. if len(devices) > 1000: - logger.warn( + logger.warning( "Ignoring device list snapshot for %s as it has >1K devs (%d)", user_id, len(devices), diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0043cbea17..73b9e120f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -52,7 +52,7 @@ class DeviceMessageHandler(object): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): - logger.warn( + logger.warning( "Dropping device message from %r with spoofed sender %r", origin, sender_user_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 526379c6f7..c4632f8984 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler): ignore_backoff=True, ) except CodeMessageException as e: - logging.warn("Error retrieving alias") + logging.warning("Error retrieving alias") if e.code == 404: result = None else: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5ea54f60be..f09a0b73c8 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -36,6 +36,8 @@ from synapse.types import ( get_verify_key_from_cross_signing_key, ) from synapse.util import unwrapFirstError +from synapse.util.async_helpers import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -49,10 +51,19 @@ class E2eKeysHandler(object): self.is_mine = hs.is_mine self.clock = hs.get_clock() + self._edu_updater = SigningKeyEduUpdater(hs, self) + + federation_registry = hs.get_federation_registry() + + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "org.matrix.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - hs.get_federation_registry().register_query_handler( + federation_registry.register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -119,9 +130,10 @@ class E2eKeysHandler(object): else: query_list.append((user_id, None)) - user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache(query_list) - ) + ( + user_ids_not_in_cache, + remote_results, + ) = yield self.store.get_user_devices_from_cache(query_list) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) for device_id, device in iteritems(devices): @@ -207,13 +219,15 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys - for user_id, key in remote_result["master_keys"].items(): - if user_id in destination_query: - cross_signing_keys["master_keys"][user_id] = key + if "master_keys" in remote_result: + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master_keys"][user_id] = key - for user_id, key in remote_result["self_signing_keys"].items(): - if user_id in destination_query: - cross_signing_keys["self_signing_keys"][user_id] = key + if "self_signing_keys" in remote_result: + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing_keys"][user_id] = key except Exception as e: failure = _exception_to_failure(e) @@ -251,7 +265,7 @@ class E2eKeysHandler(object): Returns: defer.Deferred[dict[str, dict[str, dict]]]: map from - (master|self_signing|user_signing) -> user_id -> key + (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key """ master_keys = {} self_signing_keys = {} @@ -343,7 +357,16 @@ class E2eKeysHandler(object): """ device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) - return {"device_keys": res} + ret = {"device_keys": res} + + # add in the cross-signing keys + cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + device_keys_query, None + ) + + ret.update(cross_signing_keys) + + return ret @trace @defer.inlineCallbacks @@ -688,17 +711,21 @@ class E2eKeysHandler(object): try: # get our self-signing key to verify the signatures - _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "self_signing" - ) + ( + _, + self_signing_key_id, + self_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so # that we can check if a signature that was sent is a signature of # the master key or of a device - master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "master" - ) + ( + master_key, + _, + master_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what @@ -838,9 +865,11 @@ class E2eKeysHandler(object): try: # get our user-signing key to verify the signatures - user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "user_signing" - ) + ( + user_signing_key, + user_signing_key_id, + user_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -859,7 +888,11 @@ class E2eKeysHandler(object): try: # get the target user's master key, to make sure it matches # what was sent - master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + ( + master_key, + master_key_id, + _, + ) = yield self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -1047,3 +1080,100 @@ class SignatureListItem: target_user_id = attr.ib() target_device_id = attr.ib() signature = attr.ib() + + +class SigningKeyEduUpdater(object): + """Handles incoming signing key updates from federation and updates the DB""" + + def __init__(self, hs, e2e_keys_handler): + self.store = hs.get_datastore() + self.federation = hs.get_federation_client() + self.clock = hs.get_clock() + self.e2e_keys_handler = e2e_keys_handler + + self._remote_edu_linearizer = Linearizer(name="remote_signing_key") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="signing_key_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_signing_key_update(self, origin, edu_content): + """Called on incoming signing key update from federation. Responsible for + parsing the EDU and adding to pending updates list. + + Args: + origin (string): the server that sent the EDU + edu_content (dict): the contents of the EDU + """ + + user_id = edu_content.pop("user_id") + master_key = edu_content.pop("master_key", None) + self_signing_key = edu_content.pop("self_signing_key", None) + + if get_domain_from_id(user_id) != origin: + logger.warning("Got signing key update edu for %r from %r", user_id, origin) + return + + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. + return + + self._pending_updates.setdefault(user_id, []).append( + (master_key, self_signing_key) + ) + + yield self._handle_signing_key_updates(user_id) + + @defer.inlineCallbacks + def _handle_signing_key_updates(self, user_id): + """Actually handle pending updates. + + Args: + user_id (string): the user whose updates we are processing + """ + + device_handler = self.e2e_keys_handler.device_handler + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + device_ids = [] + + logger.info("pending updates: %r", pending_updates) + + for master_key, self_signing_key in pending_updates: + if master_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "master", master_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key( + self_signing_key + ) + device_ids.append(verify_key.version) + + yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5e748687e3..45fe13c62f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): + def __init__(self, hs): + super(EventHandler, self).__init__(hs) + self.storage = hs.get_storage() + @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -172,7 +176,7 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, user.to_string(), [event], is_peeking=is_peeking + self.storage, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 488058fe68..8cafcfdab0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( make_deferred_yieldable, @@ -109,6 +110,8 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -180,7 +183,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn( + logger.warning( "[%s %s] Received event failed sanity checks", room_id, event_id ) raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) @@ -301,7 +304,7 @@ class FederationHandler(BaseHandler): # following. if sent_to_us_directly: - logger.warn( + logger.warning( "[%s %s] Rejecting: failed to fetch %d prev events: %s", room_id, event_id, @@ -324,7 +327,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.store.get_state_groups_ids(room_id, seen) + ours = yield self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -350,10 +353,11 @@ class FederationHandler(BaseHandler): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ( + remote_state, + got_auth_chain, + ) = yield self.federation_client.get_state_for_room( + origin, room_id, p ) # we want the state *after* p; get_state_for_room returns the @@ -405,7 +409,7 @@ class FederationHandler(BaseHandler): state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: - logger.warn( + logger.warning( "[%s %s] Error attempting to resolve state at missing " "prev_events", room_id, @@ -518,7 +522,9 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) + logger.warning( + "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e + ) return logger.info( @@ -545,7 +551,7 @@ class FederationHandler(BaseHandler): yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: - logger.warn( + logger.warning( "[%s %s] Received prev_event %s failed history check.", room_id, event_id, @@ -888,7 +894,7 @@ class FederationHandler(BaseHandler): # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, + self.storage, self.server_name, list(extremities_events.values()), redact=False, @@ -1059,7 +1065,7 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn( + logger.warning( "Rejecting event %s which has %i prev_events", ev.event_id, len(ev.prev_event_ids()), @@ -1067,7 +1073,7 @@ class FederationHandler(BaseHandler): raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn( + logger.warning( "Rejecting event %s which has %i auth_events", ev.event_id, len(ev.auth_event_ids()), @@ -1101,7 +1107,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_invite_join(self, target_hosts, room_id, joinee, content): """ Attempts to join the `joinee` to the room `room_id` via the - server `target_host`. + servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial event that we can fill out and sign. This is then sent to the @@ -1110,6 +1116,15 @@ class FederationHandler(BaseHandler): We suspend processing of any received events from this room until we have finished processing the join. + + Args: + target_hosts (Iterable[str]): List of servers to attempt to join the room with. + + room_id (str): The ID of the room to join. + + joinee (str): The User ID of the joining user. + + content (dict): The event content to use for the join event. """ logger.debug("Joining %s to %s", joinee, room_id) @@ -1169,6 +1184,22 @@ class FederationHandler(BaseHandler): yield self._persist_auth_tree(origin, auth_chain, state, event) + # Check whether this room is the result of an upgrade of a room we already know + # about. If so, migrate over user information + predecessor = yield self.store.get_room_predecessor(room_id) + if not predecessor: + return + old_room_id = predecessor["room_id"] + logger.debug( + "Found predecessor for %s during remote join: %s", room_id, old_room_id + ) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + yield member_handler.transfer_room_state_on_room_upgrade( + old_room_id, room_id + ) + logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] @@ -1203,7 +1234,7 @@ class FederationHandler(BaseHandler): with nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: - logger.warn( + logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e ) @@ -1250,7 +1281,7 @@ class FederationHandler(BaseHandler): builder=builder ) except AuthError as e: - logger.warn("Failed to create join %r because %s", event, e) + logger.warning("Failed to create join to %s because %s", room_id, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( @@ -1494,7 +1525,7 @@ class FederationHandler(BaseHandler): room_version, event, context, do_sig_check=False ) except AuthError as e: - logger.warn("Failed to create new leave %r because %s", event, e) + logger.warning("Failed to create new leave %r because %s", event, e) raise e return event @@ -1549,7 +1580,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1578,7 +1609,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1606,7 +1637,7 @@ class FederationHandler(BaseHandler): events = yield self.store.get_backfill_events(room_id, pdu_list, limit) - events = yield filter_events_for_server(self.store, origin, events) + events = yield filter_events_for_server(self.storage, origin, events) return events @@ -1636,7 +1667,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.store, origin, [event]) + events = yield filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -1788,7 +1819,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn("Rejecting %s because %s", e.event_id, err.msg) + logger.warning("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1841,12 +1872,7 @@ class FederationHandler(BaseHandler): if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) - - context.rejected = RejectedReason.AUTH_ERROR + context = yield self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: yield self._check_for_soft_fail(event, state, backfilled) @@ -1902,7 +1928,7 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = yield self.store.get_state_groups( + state_sets = yield self.state_store.get_state_groups( event.room_id, extrem_ids ) state_sets = list(state_sets.values()) @@ -1938,7 +1964,7 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn("Soft-failing %r because %s", event, e) + logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks @@ -1993,7 +2019,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events + self.storage, origin, missing_events ) return missing_events @@ -2015,12 +2041,12 @@ class FederationHandler(BaseHandler): Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) try: - yield self._update_auth_events_and_context_for_auth( + context = yield self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2037,8 +2063,10 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) - raise e + logger.warning("Failed auth resolution for %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + + return context @defer.inlineCallbacks def _update_auth_events_and_context_for_auth( @@ -2062,7 +2090,7 @@ class FederationHandler(BaseHandler): auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) @@ -2101,7 +2129,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] @@ -2142,7 +2170,7 @@ class FederationHandler(BaseHandler): if event.internal_metadata.is_outlier(): logger.info("Skipping auth_event fetch for outlier") - return + return context # FIXME: Assumes we have and stored all the state for all the # prev_events @@ -2151,7 +2179,7 @@ class FederationHandler(BaseHandler): ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2198,10 +2226,12 @@ class FederationHandler(BaseHandler): auth_events.update(new_state) - yield self._update_context_for_auth_events( + context = yield self._update_context_for_auth_events( event, context, auth_events, event_key ) + return context + @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2210,14 +2240,16 @@ class FederationHandler(BaseHandler): Args: event (Event): The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context (synapse.events.snapshot.EventContext): initial event context auth_events (dict[(str, str)->str]): Events to update in the event context. event_key ((str, str)): (type, state_key) for the current event. this will not be included in the current_state in the context. + + Returns: + Deferred[EventContext]: new event context """ state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key @@ -2234,7 +2266,7 @@ class FederationHandler(BaseHandler): # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -2242,7 +2274,7 @@ class FederationHandler(BaseHandler): current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, @@ -2431,10 +2463,12 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying new third party invite %r because %s", event, e) + logger.warning("Denying new third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) + + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: @@ -2487,7 +2521,7 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying third party invite %r because %s", event, e) + logger.warning("Denying third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) @@ -2495,6 +2529,7 @@ class FederationHandler(BaseHandler): # though the sender isn't a local user. event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender) + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) @@ -2664,7 +2699,7 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) else: - max_stream_id = yield self.store.persist_events( + max_stream_id = yield self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 46eb9ee88b..92fecbfc44 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -392,7 +392,7 @@ class GroupsLocalHandler(object): try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: - logger.warn("No profile for user %s: %s", user_id, e) + logger.warning("No profile for user %s: %s", user_id, e) user_profile = {} return {"state": "invite", "user_profile": user_profile} diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ba99ddf76d..000fbf090f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler): changed = False if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) - logger.warn("Received %d response while unbinding threepid", e.code) + logger.warning("Received %d response while unbinding threepid", e.code) else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") @@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " @@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f991efeee3..81dce96f4b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state def snapshot_all_rooms( self, @@ -126,8 +128,8 @@ class InitialSyncHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + account_data, account_data_by_room = yield self.store.get_account_data_for_user( + user_id ) public_room_ids = yield self.store.get_public_room_ids() @@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, [event.event_id] + self.state_store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client(self.store, user_id, messages) + messages = yield filter_events_for_client( + self.storage, user_id, messages + ) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.store.get_state_for_events([member_event_id]) + room_state = yield self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace("room_key", token) @@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f8cce8ffe..d682dc2b7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,6 +59,8 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @@ -74,15 +76,16 @@ class MessageHandler(object): Raises: SynapseError if something went wrong. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -135,12 +138,12 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events + self.storage, user_id, last_events ) event = last_events[0] if visible_events: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -151,9 +154,10 @@ class MessageHandler(object): % (user_id, room_id, at_token), ) else: - membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable(room_id, user_id) - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( @@ -161,7 +165,7 @@ class MessageHandler(object): ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] @@ -234,6 +238,7 @@ class EventCreationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -687,7 +692,7 @@ class EventCreationHandler(object): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as err: - logger.warn("Denying new event %r because %s", event, err) + logger.warning("Denying new event %r because %s", event, err) raise err # Ensure that we can round trip before trying to persist in db @@ -868,7 +873,7 @@ class EventCreationHandler(object): if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - (event_stream_id, max_stream_id) = yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event, context=context ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..97f15a1c32 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -69,6 +69,8 @@ class PaginationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname @@ -210,9 +212,10 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") with (yield self.pagination_lock.read(room_id)): - membership, member_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + member_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -255,7 +258,7 @@ class PaginationHandler(object): events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, user_id, events, is_peeking=(member_event_id is None) + self.storage, user_id, events, is_peeking=(member_event_id is None) ) if not events: @@ -274,7 +277,7 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) @@ -295,10 +298,8 @@ class PaginationHandler(object): } if state: - chunk["state"] = ( - yield self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event - ) + chunk["state"] = yield self._event_serializer.serialize_events( + state, time_now, as_client_event=as_client_event ) return chunk diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 8690f69d45..22e0a04da4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler): ratelimit=False, # Try to hide that these events aren't atomic. ) except Exception as e: - logger.warn( + logger.warning( "Failed to update join event for room %s - %s", room_id, str(e) ) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 3e4d8c93a4..e3b528d271 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler): self.read_marker_linearizer = Linearizer(name="read_marker") self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def received_client_read_marker(self, room_id, user_id, event_id): + async def received_client_read_marker(self, room_id, user_id, event_id): """Updates the read marker for a given user in a given room if the event ID given is ahead in the stream relative to the current read marker. @@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler): the read marker has changed. """ - with (yield self.read_marker_linearizer.queue((room_id, user_id))): - existing_read_marker = yield self.store.get_account_data_for_room_and_type( + with await self.read_marker_linearizer.queue((room_id, user_id)): + existing_read_marker = await self.store.get_account_data_for_room_and_type( user_id, room_id, "m.fully_read" ) @@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream - should_update = yield self.store.is_event_after( + should_update = await self.store.is_event_after( event_id, existing_read_marker["event_id"] ) if should_update: content = {"event_id": event_id} - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6854c751a6..9283c039e3 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - @defer.inlineCallbacks - def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler): ) ) - yield self._handle_new_receipts(receipts) + await self._handle_new_receipts(receipts) - @defer.inlineCallbacks - def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ min_batch_id = None max_batch_id = None for receipt in receipts: - res = yield self.store.insert_receipt( + res = await self.store.insert_receipt( receipt.room_id, receipt.receipt_type, receipt.user_id, @@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids + await maybe_awaitable( + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) ) return True - @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler): data={"ts": int(self.clock.time_msec())}, ) - is_new = yield self._handle_new_receipts([receipt]) + is_new = await self._handle_new_receipts([receipt]) if not is_new: return - yield self.federation.send_read_receipt(receipt) - - @defer.inlineCallbacks - def get_receipts_for_room(self, room_id, to_key): - """Gets all receipts for a room, upto the given key. - """ - result = yield self.store.get_linearized_receipts_for_room( - room_id, to_key=to_key - ) - - if not result: - return [] - - return result + await self.federation.send_read_receipt(receipt) class ReceiptEventSource(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 53410f120b..cff6b0d375 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = ( - yield room_member_handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_alias ) room_id = room_id.to_string() else: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..e92b2eafd5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler): old_room_id, new_version, # args for _upgrade_room ) + return ret @defer.inlineCallbacks @@ -147,21 +148,22 @@ class RoomCreationHandler(BaseHandler): # we create and auth the tombstone event before properly creating the new # room, to check our user has perms in the old room. - tombstone_event, tombstone_context = ( - yield self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, + ( + tombstone_event, + tombstone_context, + ) = yield self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, }, - token_id=requester.access_token_id, - ) + }, + token_id=requester.access_token_id, ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( @@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler): requester, old_room_id, new_room_id, old_room_state ) - # and finally, shut down the PLs in the old room, and update them in the new + # Copy over user push rules, tags and migrate room directory state + yield self.room_member_handler.transfer_room_state_on_room_upgrade( + old_room_id, new_room_id + ) + + # finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state @@ -822,6 +829,8 @@ class RoomContextHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, event_filter): @@ -848,7 +857,7 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, user.to_string(), events, is_peeking=is_peeking + self.storage, user.to_string(), events, is_peeking=is_peeking ) event = yield self.store.get_event( @@ -890,7 +899,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.store.get_state_for_events( + state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) @@ -922,7 +931,7 @@ class RoomEventSource(object): from_token = RoomStreamToken.parse(from_key) if from_token.topological: - logger.warn("Stream has topological part!!!! %r", from_key) + logger.warning("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) app_service = self.store.get_app_service_by_user_id(user.to_string()) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 380e2fad5e..06d09c2947 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -203,10 +203,6 @@ class RoomMemberHandler(object): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - # Copy over user state if we're joining an upgraded room - yield self.copy_user_state_if_room_upgrade( - room_id, requester.user.to_string() - ) yield self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -455,11 +451,6 @@ class RoomMemberHandler(object): requester, remote_room_hosts, room_id, target, content ) - # Copy over user state if this is a join on an remote upgraded room - yield self.copy_user_state_if_room_upgrade( - room_id, requester.user.to_string() - ) - return remote_join_response elif effective_membership_state == Membership.LEAVE: @@ -498,36 +489,72 @@ class RoomMemberHandler(object): return res @defer.inlineCallbacks - def copy_user_state_if_room_upgrade(self, new_room_id, user_id): - """Copy user-specific information when they join a new room if that new room is the + def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + """Upon our server becoming aware of an upgraded room, either by upgrading a room + ourselves or joining one, we can transfer over information from the previous room. + + Copies user state (tags/push rules) for every local user that was in the old room, as + well as migrating the room directory state. + + Args: + old_room_id (str): The ID of the old room + + room_id (str): The ID of the new room + + Returns: + Deferred + """ + # Find all local users that were in the old room and copy over each user's state + users = yield self.store.get_users_in_room(old_room_id) + yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) + + # Add new room to the room directory if the old room was there + # Remove old room from the room directory + old_room = yield self.store.get_room(old_room_id) + if old_room and old_room["is_public"]: + yield self.store.set_room_is_public(old_room_id, False) + yield self.store.set_room_is_public(room_id, True) + + @defer.inlineCallbacks + def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + """Copy user-specific information when they join a new room when that new room is the result of a room upgrade Args: - new_room_id (str): The ID of the room the user is joining - user_id (str): The ID of the user + old_room_id (str): The ID of upgraded room + new_room_id (str): The ID of the new room + user_ids (Iterable[str]): User IDs to copy state for Returns: Deferred """ - # Check if the new room is an upgraded room - predecessor = yield self.store.get_room_predecessor(new_room_id) - if not predecessor: - return logger.debug( - "Found predecessor for %s: %s. Copying over room tags and push " "rules", + "Copying over room tags and push rules from %s to %s for users %s", + old_room_id, new_room_id, - predecessor, + user_ids, ) - # It is an upgraded room. Copy over old tags - yield self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], new_room_id, user_id - ) - # Copy over push rules - yield self.store.copy_push_rules_from_room_to_room_for_user( - predecessor["room_id"], new_room_id, user_id - ) + for user_id in user_ids: + try: + # It is an upgraded room. Copy over old tags + yield self.copy_room_tags_and_direct_to_room( + old_room_id, new_room_id, user_id + ) + # Copy over push rules + yield self.store.copy_push_rules_from_room_to_room_for_user( + old_room_id, new_room_id, user_id + ) + except Exception: + logger.exception( + "Error copying tags and/or push rules from rooms %s to %s for user %s. " + "Skipping...", + old_room_id, + new_room_id, + user_id, + ) + continue @defer.inlineCallbacks def send_membership_event(self, requester, event, context, ratelimit=True): @@ -759,22 +786,25 @@ class RoomMemberHandler(object): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_keys, fallback_public_key, display_name = ( - yield self.identity_handler.ask_id_server_for_third_party_invite( - requester=requester, - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url, - id_access_token=id_access_token, - ) + ( + token, + public_keys, + fallback_public_key, + display_name, + ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + requester=requester, + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url, + id_access_token=id_access_token, ) yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd5e90bacb..56ed262a1f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -35,6 +35,8 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -221,7 +223,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -271,7 +273,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) room_events.extend(events) @@ -340,11 +342,11 @@ class SearchHandler(BaseHandler): ) res["events_before"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_before"] + self.storage, user.to_string(), res["events_before"] ) res["events_after"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_after"] + self.storage, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( @@ -372,7 +374,7 @@ class SearchHandler(BaseHandler): [(EventTypes.Member, sender) for sender in senders] ) - state = yield self.store.get_state_for_event( + state = yield self.state_store.get_state_for_event( last_event_id, state_filter ) @@ -394,15 +396,11 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = ( - yield self._event_serializer.serialize_events( - context["events_before"], time_now - ) + context["events_before"] = yield self._event_serializer.serialize_events( + context["events_before"], time_now ) - context["events_after"] = ( - yield self._event_serializer.serialize_events( - context["events_after"], time_now - ) + context["events_after"] = yield self._event_serializer.serialize_events( + context["events_after"], time_now ) state_results = {} diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 26bc276692..7f7d56390e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler): user_deltas = {} # Then count deltas for total_events and total_event_bytes. - room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( + ( + room_count, + user_count, + ) = yield self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d99160e9d7..b536d410e5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -230,6 +230,8 @@ class SyncHandler(object): self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() self.auth = hs.get_auth() + self.storage = hs.get_storage() + self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -417,7 +419,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -470,7 +472,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) loaded_recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -509,7 +511,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -580,7 +582,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -757,11 +759,11 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) @@ -781,7 +783,7 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.store.get_state_ids_for_event( + state_at_timeline_start = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: @@ -810,7 +812,7 @@ class SyncHandler(object): ) if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: @@ -841,7 +843,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -1204,10 +1206,11 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_updated_account_data_for_user( + user_id, since_token.account_data_key ) push_rules_changed = yield self.store.have_push_rules_changed_for_user( @@ -1219,9 +1222,10 @@ class SyncHandler(object): sync_config.user ) else: - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(sync_config.user.to_string()) - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 29aa1e5aaf..8363d887a9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key diff --git a/synapse/http/client.py b/synapse/http/client.py index cdf828a4ff..d4c285445e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,6 +45,7 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred @@ -183,7 +184,15 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + hs, + treq_args={}, + ip_whitelist=None, + ip_blacklist=None, + http_proxy=None, + https_proxy=None, + ): """ Args: hs (synapse.server.HomeServer) @@ -192,6 +201,8 @@ class SimpleHttpClient(object): we may not request. ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. + http_proxy (bytes): proxy server to use for http connections. host[:port] + https_proxy (bytes): proxy server to use for https connections. host[:port] """ self.hs = hs @@ -236,11 +247,13 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = Agent( + self.agent = ProxyAgent( self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, + http_proxy=http_proxy, + https_proxy=https_proxy, ) if self._ip_blacklist: @@ -535,7 +548,7 @@ class SimpleHttpClient(object): b"Content-Length" in resp_headers and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, "Requested file is too large > %r bytes" % (self.max_size,), @@ -543,7 +556,7 @@ class SimpleHttpClient(object): ) if response.code > 299: - logger.warn("Got %d when downloading %s" % (response.code, url)) + logger.warning("Got %d when downloading %s" % (response.code, url)) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py new file mode 100644 index 0000000000..be7b2ceb8e --- /dev/null +++ b/synapse/http/connectproxyclient.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from zope.interface import implementer + +from twisted.internet import defer, protocol +from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.protocol import connectionDone +from twisted.web import http + +logger = logging.getLogger(__name__) + + +class ProxyConnectError(ConnectError): + pass + + +@implementer(IStreamClientEndpoint) +class HTTPConnectProxyEndpoint(object): + """An Endpoint implementation which will send a CONNECT request to an http proxy + + Wraps an existing HostnameEndpoint for the proxy. + + When we get the connect() request from the connection pool (via the TLS wrapper), + we'll first connect to the proxy endpoint with a ProtocolFactory which will make the + CONNECT request. Once that completes, we invoke the protocolFactory which was passed + in. + + Args: + reactor: the Twisted reactor to use for the connection + proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the + proxy + host (bytes): hostname that we want to CONNECT to + port (int): port that we want to connect to + """ + + def __init__(self, reactor, proxy_endpoint, host, port): + self._reactor = reactor + self._proxy_endpoint = proxy_endpoint + self._host = host + self._port = port + + def __repr__(self): + return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) + + def connect(self, protocolFactory): + f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + d = self._proxy_endpoint.connect(f) + # once the tcp socket connects successfully, we need to wait for the + # CONNECT to complete. + d.addCallback(lambda conn: f.on_connection) + return d + + +class HTTPProxiedClientFactory(protocol.ClientFactory): + """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect. + + Once the CONNECT completes, invokes the original ClientFactory to build the + HTTP Protocol object and run the rest of the connection. + + Args: + dst_host (bytes): hostname that we want to CONNECT to + dst_port (int): port that we want to connect to + wrapped_factory (protocol.ClientFactory): The original Factory + """ + + def __init__(self, dst_host, dst_port, wrapped_factory): + self.dst_host = dst_host + self.dst_port = dst_port + self.wrapped_factory = wrapped_factory + self.on_connection = defer.Deferred() + + def startedConnecting(self, connector): + return self.wrapped_factory.startedConnecting(connector) + + def buildProtocol(self, addr): + wrapped_protocol = self.wrapped_factory.buildProtocol(addr) + + return HTTPConnectProtocol( + self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + ) + + def clientConnectionFailed(self, connector, reason): + logger.debug("Connection to proxy failed: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + logger.debug("Connection to proxy lost: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionLost(connector, reason) + + +class HTTPConnectProtocol(protocol.Protocol): + """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect + + Args: + host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + to put in the CONNECT request + + port (int): The original HTTP(s) port to put in the CONNECT request + + wrapped_protocol (interfaces.IProtocol): the original protocol (probably + HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + + connected_deferred (Deferred): a Deferred which will be callbacked with + wrapped_protocol when the CONNECT completes + """ + + def __init__(self, host, port, wrapped_protocol, connected_deferred): + self.host = host + self.port = port + self.wrapped_protocol = wrapped_protocol + self.connected_deferred = connected_deferred + self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.http_setup_client.on_connected.addCallback(self.proxyConnected) + + def connectionMade(self): + self.http_setup_client.makeConnection(self.transport) + + def connectionLost(self, reason=connectionDone): + if self.wrapped_protocol.connected: + self.wrapped_protocol.connectionLost(reason) + + self.http_setup_client.connectionLost(reason) + + if not self.connected_deferred.called: + self.connected_deferred.errback(reason) + + def proxyConnected(self, _): + self.wrapped_protocol.makeConnection(self.transport) + + self.connected_deferred.callback(self.wrapped_protocol) + + # Get any pending data from the http buf and forward it to the original protocol + buf = self.http_setup_client.clearLineBuffer() + if buf: + self.wrapped_protocol.dataReceived(buf) + + def dataReceived(self, data): + # if we've set up the HTTP protocol, we can send the data there + if self.wrapped_protocol.connected: + return self.wrapped_protocol.dataReceived(data) + + # otherwise, we must still be setting up the connection: send the data to the + # setup client + return self.http_setup_client.dataReceived(data) + + +class HTTPConnectSetupClient(http.HTTPClient): + """HTTPClient protocol to send a CONNECT message for proxies and read the response. + + Args: + host (bytes): The hostname to send in the CONNECT message + port (int): The port to send in the CONNECT message + """ + + def __init__(self, host, port): + self.host = host + self.port = port + self.on_connected = defer.Deferred() + + def connectionMade(self): + logger.debug("Connected to proxy, sending CONNECT") + self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + self.endHeaders() + + def handleStatus(self, version, status, message): + logger.debug("Got Status: %s %s %s", status, message, version) + if status != b"200": + raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + + def handleEndHeaders(self): + logger.debug("End Headers") + self.on_connected.callback(None) + + def handleResponse(self, body): + pass diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 3fe4ffb9e5..021b233a7d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -148,7 +148,7 @@ class SrvResolver(object): # Try something in the cache, else rereaise cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( + logger.warning( "Failed to resolve %r, falling back to cache. %r", service_name, e ) return list(cache_entry) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3f7c93ffcb..691380abda 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): body = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, @@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object): except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. - logger.warn( + logger.warning( "{%s} [%s] Failed to get error response: %s %s: %s", request.txn_id, request.destination, @@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object): break except RequestSendFailed as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object): raise except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object): d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py new file mode 100644 index 0000000000..332da02a8d --- /dev/null +++ b/synapse/http/proxyagent.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.python.failure import Failure +from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.error import SchemeNotSupported +from twisted.web.iweb import IAgent + +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint + +logger = logging.getLogger(__name__) + +_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") + + +@implementer(IAgent) +class ProxyAgent(_AgentBase): + """An Agent implementation which will use an HTTP proxy if one was requested + + Args: + reactor: twisted reactor to place outgoing + connections. + + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + verification parameters of OpenSSL. The default is to use a + `BrowserLikePolicyForHTTPS`, so unless you have special + requirements you can leave this as-is. + + connectTimeout (float): The amount of time that this Agent will wait + for the peer to accept a connection. + + bindAddress (bytes): The local address for client sockets to bind to. + + pool (HTTPConnectionPool|None): connection pool to be used. If None, a + non-persistent pool instance will be created. + """ + + def __init__( + self, + reactor, + contextFactory=BrowserLikePolicyForHTTPS(), + connectTimeout=None, + bindAddress=None, + pool=None, + http_proxy=None, + https_proxy=None, + ): + _AgentBase.__init__(self, reactor, pool) + + self._endpoint_kwargs = {} + if connectTimeout is not None: + self._endpoint_kwargs["timeout"] = connectTimeout + if bindAddress is not None: + self._endpoint_kwargs["bindAddress"] = bindAddress + + self.http_proxy_endpoint = _http_proxy_endpoint( + http_proxy, reactor, **self._endpoint_kwargs + ) + + self.https_proxy_endpoint = _http_proxy_endpoint( + https_proxy, reactor, **self._endpoint_kwargs + ) + + self._policy_for_https = contextFactory + self._reactor = reactor + + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Issue a request to the server indicated by the given uri. + + Supports `http` and `https` schemes. + + An existing connection from the connection pool may be used or a new one may be + created. + + See also: twisted.web.iweb.IAgent.request + + Args: + method (bytes): The request method to use, such as `GET`, `POST`, etc + + uri (bytes): The location of the resource to request. + + headers (Headers|None): Extra headers to send with the request + + bodyProducer (IBodyProducer|None): An object which can generate bytes to + make up the body of this request (for example, the properly encoded + contents of a file for a file upload). Or, None if the request is to + have no body. + + Returns: + Deferred[IResponse]: completes when the header of the response has + been received (regardless of the response status code). + """ + uri = uri.strip() + if not _VALID_URI.match(uri): + raise ValueError("Invalid URI {!r}".format(uri)) + + parsed_uri = URI.fromBytes(uri) + pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + request_path = parsed_uri.originForm + + if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + # Cache *all* connections under the same key, since we are only + # connecting to a single destination, the proxy: + pool_key = ("http-proxy", self.http_proxy_endpoint) + endpoint = self.http_proxy_endpoint + request_path = uri + elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self.https_proxy_endpoint, + parsed_uri.host, + parsed_uri.port, + ) + else: + # not using a proxy + endpoint = HostnameEndpoint( + self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs + ) + + logger.debug("Requesting %s via %s", uri, endpoint) + + if parsed_uri.scheme == b"https": + tls_connection_creator = self._policy_for_https.creatorForNetloc( + parsed_uri.host, parsed_uri.port + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + elif parsed_uri.scheme == b"http": + pass + else: + return defer.fail( + Failure( + SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,)) + ) + ) + + return self._requestWithEndpoint( + pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path + ) + + +def _http_proxy_endpoint(proxy, reactor, **kwargs): + """Parses an http proxy setting and returns an endpoint for the proxy + + Args: + proxy (bytes|None): the proxy setting + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint + + Returns: + interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, + or None + """ + if proxy is None: + return None + + # currently we only support hostname:port. Some apps also support + # protocol://<host>[:port], which allows a way of requiring a TLS connection to the + # proxy. + + host, port = parse_host_port(proxy, default_port=1080) + return HostnameEndpoint(reactor, host, port, **kwargs) + + +def parse_host_port(hostport, default_port=None): + # could have sworn we had one of these somewhere else... + if b":" in hostport: + host, port = hostport.rsplit(b":", 1) + try: + port = int(port) + return host, port + except ValueError: + # the thing after the : wasn't a valid port; presumably this is an + # IPv6 address. + pass + + return hostport, default_port diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 46af27c8f6..58f9cc61c8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -170,7 +170,7 @@ class RequestMetrics(object): tag = context.tag if context != self.start_context: - logger.warn( + logger.warning( "Context have unexpectedly changed %r, %r", context, self.start_context, diff --git a/synapse/http/server.py b/synapse/http/server.py index 2ccb210fd6..943d12c907 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -454,7 +454,7 @@ def respond_with_json( # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: - logger.warn( + logger.warning( "Not sending response to request %s, already disconnected.", request ) return diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 274c1a6a87..e9a5e46ced 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False): try: content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: - logger.warn("Unable to decode UTF-8") + logger.warning("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) try: content = json.loads(content_unicode) except Exception as e: - logger.warn("Unable to parse JSON: %s", e) + logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content diff --git a/synapse/http/site.py b/synapse/http/site.py index df5274c177..ff8184a3d0 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -199,7 +199,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warn( + logger.warning( "Error processing request %r: %s %s", self, reason.type, reason.value ) @@ -305,7 +305,7 @@ class SynapseRequest(Request): try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: - logger.warn("Failed to stop metrics: %r", e) + logger.warning("Failed to stop metrics: %r", e) class XForwardedForRequest(SynapseRequest): diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 3220e985a9..334ddaf39a 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} def parse_drain_configs( - drains: dict + drains: dict, ) -> typing.Generator[DrainConfiguration, None, None]: """ Parse the drain configurations. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 370000e377..2c1fb9ddac 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -294,7 +294,7 @@ class LoggingContext(object): """Enters this logging context into thread local storage""" old_context = self.set_current_context(self) if self.previous_context != old_context: - logger.warn( + logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..af161a81d7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,7 @@ class Notifier(object): self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -425,7 +426,10 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 2bbdd11941..1ba7bcd4d8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object): room_members = yield self.store.get_joined_users_from_context(event, context) - (power_levels, sender_power_level) = ( - yield self._get_power_levels_and_sender_level(event, context) - ) + ( + power_levels, + sender_power_level, + ) = yield self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 42e5b0c0a5..8c818a86bf 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -234,14 +234,12 @@ class EmailPusher(object): return self.last_stream_ordering = last_stream_ordering - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..e994037be6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -64,6 +64,7 @@ class HttpPusher(object): def __init__(self, hs, pusherdict): self.hs = hs self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() self.user_id = pusherdict["user_name"] @@ -102,7 +103,7 @@ class HttpPusher(object): if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] @@ -210,14 +211,12 @@ class HttpPusher(object): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -246,7 +245,7 @@ class HttpPusher(object): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn( + logger.warning( "Giving up on a notification to user %s, " "pushkey %s", self.user_id, self.pushkey, @@ -299,7 +298,7 @@ class HttpPusher(object): if pk != self.pushkey: # for sanity, we only remove the pushkey if it # was the one we actually sent... - logger.warn( + logger.warning( ("Ignoring rejected pushkey %s because we" " didn't send it"), pk, ) @@ -329,7 +328,7 @@ class HttpPusher(object): return d ctx = yield push_tools.get_context_for_event( - self.store, self.state_handler, event, self.user_id + self.storage, self.state_handler, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5b16ab4ae8..1d15a06a58 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -119,6 +119,7 @@ class Mailer(object): self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() + self.storage = hs.get_storage() self.app_name = app_name logger.info("Created Mailer for app_name %s" % app_name) @@ -389,7 +390,7 @@ class Mailer(object): } the_events = yield filter_events_for_client( - self.store, user_id, results["events_before"] + self.storage, user_id, results["events_before"] ) the_events.append(notif_event) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 5ed9147de4..b1587183a8 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object): pattern = UserID.from_string(user_id).localpart if not pattern: - logger.warn("event_match condition with no pattern") + logger.warning("event_match condition with no pattern") return False # XXX: optimisation: cache our pattern regexps @@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False): regex_cache[(glob, word_boundary)] = r return r.search(value) except re.error: - logger.warn("Failed to parse glob to regex: %r", glob) + logger.warning("Failed to parse glob to regex: %r", glob) return False diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a54051a726..de5c101a58 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.push.presentable_names import calculate_room_name, name_from_member_event +from synapse.storage import Storage @defer.inlineCallbacks @@ -43,22 +44,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(store, state_handler, ev, user_id): +def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield store.get_state_ids_for_event(ev.event_id) + room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room name = yield calculate_room_name( - store, room_state_ids, user_id, fallback_to_single_member=False + storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield store.get_event(sender_state_event_id) + sender_state_event = yield storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 08e840fdc2..0f6992202d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -103,9 +103,7 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = ( - yield self.store.get_latest_push_action_stream_ordering() - ) + last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() yield self.store.add_pusher( user_id=user_id, diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 03560c1f0e..c8056b0c0c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -110,14 +110,14 @@ class ReplicationEndpoint(object): return {} @abc.abstractmethod - def _handle_request(self, request, **kwargs): + async def _handle_request(self, request, **kwargs): """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - Deferred[dict]: A JSON serialisable dict to be used as response - body of request. + tuple[int, dict]: HTTP status code and a JSON serialisable dict + to be used as response body of request. """ pass @@ -180,7 +180,7 @@ class ReplicationEndpoint(object): if e.code != 504 or not cls.RETRY_ON_TIMEOUT: raise - logger.warn("%s request timed out", cls.NAME) + logger.warning("%s request timed out", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 2f16955954..9af4e7e173 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request): + async def _handle_request(self, request): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = yield EventContext.deserialize( - self.store, event_payload["context"] - ) + context = EventContext.deserialize(self.store, event_payload["context"]) event_and_contexts.append((event, context)) logger.info("Got %d events from federation", len(event_and_contexts)) - yield self.federation_handler.persist_events_and_notify( + await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) @@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} - @defer.inlineCallbacks - def _handle_request(self, request, edu_type): + async def _handle_request(self, request, edu_type): with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = yield self.registry.on_edu(edu_type, origin, edu_content) + result = await self.registry.on_edu(edu_type, origin, edu_content) return 200, result @@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): """ return {"args": args} - @defer.inlineCallbacks - def _handle_request(self, request, query_type): + async def _handle_request(self, request, query_type): with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): logger.info("Got %r query", query_type) - result = yield self.registry.on_query(query_type, args) + result = await self.registry.on_query(query_type, args) return 200, result @@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): """ return {} - @defer.inlineCallbacks - def _handle_request(self, request, room_id): - yield self.store.clean_room_for_join(room_id) + async def _handle_request(self, request, room_id): + await self.store.clean_room_for_join(room_id) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 786f5232b2..798b9d3af5 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest ) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b9ce3477ad..cc1f249740 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID @@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - yield self.federation_handler.do_invite_join( + await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) @@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "remote_room_hosts": remote_room_hosts, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = yield self.federation_handler.do_remotely_reject_invite( + event = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() @@ -148,9 +144,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # The 'except' clause is very broad, but we need to # capture everything from DNS failures upwards # - logger.warn("Failed to reject invite: %s", e) + logger.warning("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite(user_id, room_id) + await self.store.locally_reject_invite(user_id, room_id) ret = {} return 200, ret diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 38260256cf..915cfb9430 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -74,11 +72,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "address": address, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.registration_handler.register_with_store( + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], was_guest=content["was_guest"], @@ -117,14 +114,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): """ return {"auth_result": auth_result, "access_token": access_token} - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) auth_result = content["auth_result"] access_token = content["access_token"] - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=user_id, auth_result=auth_result, access_token=access_token ) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index adb9b2f7f4..9bafd60b14 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request, event_id): + async def _handle_request(self, request, event_id): with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) @@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = yield EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.store, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] @@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - yield self.event_creation_handler.persist_and_notify_client_event( + await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 61557665a7..de50748c30 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -15,6 +15,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() - result["device_lists"] = self._device_list_id_gen.get_current_token() + # The user signature stream uses the same stream ID generator as the + # device list stream, so set them both to the device list ID + # generator's current token. + current_token = self._device_list_id_gen.get_current_token() + result[DeviceListsStream.NAME] = current_token + result[UserSignatureStream.NAME] = current_token return result def process_replication_rows(self, stream_name, token, rows): - if stream_name == "device_lists": + if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) for row in rows: self._invalidate_caches_for_devices(token, row.user_id, row.destination) + elif stream_name == UserSignatureStream.NAME: + for row in rows: + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a44ceb00e7..563ce0fc53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -168,7 +168,7 @@ class ReplicationClientHandler(object): if self.connection: self.connection.send_command(cmd) else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) + logger.warning("Queuing command as not connected: %r", cmd.NAME) self.pending_commands.append(cmd) def send_federation_ack(self, token): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5ffdf2675d..b64f3f44b5 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -249,7 +249,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return handler(cmd) def close(self): - logger.warn("[%s] Closing connection", self.id()) + logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 634f636dc9..5f52264e84 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -45,5 +45,6 @@ STREAMS_MAP = { _base.TagAccountDataStream, _base.AccountDataStream, _base.GroupServerStream, + _base.UserSignatureStream, ) } diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..9e45429d49 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple( "GroupsStreamRow", ("group_id", "user_id", "type", "content"), # str # str # str # dict ) +UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str class Stream(object): @@ -438,3 +439,20 @@ class GroupServerStream(Stream): self.update_function = store.get_all_groups_changes super(GroupServerStream, self).__init__(hs) + + +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + NAME = "user_signature" + _LIMITED = False + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token + self.update_function = store.get_all_user_signature_changes_for_remotes + + super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 939418ee2b..5c2a2eb593 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -286,7 +286,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, stream_ordering ) if not r: - logger.warn( + logger.warning( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", ts, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8414af08cb..24a0ce74f2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -203,10 +203,11 @@ class LoginRestServlet(RestServlet): address = address.lower() # Check for login providers that support 3pid login types - canonical_user_id, callback_3pid = ( - yield self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) + ( + canonical_user_id, + callback_3pid, + ) = yield self.auth_handler.check_password_provider_3pid( + medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded @@ -221,7 +222,7 @@ class LoginRestServlet(RestServlet): medium, address ) if not user_id: - logger.warn( + logger.warning( "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet): def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = ( - yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + token ) result = yield self._register_device_with_callback(user_id, login_submission) @@ -380,7 +381,7 @@ class CasTicketServlet(RestServlet): self.cas_displayname_attribute = hs.config.cas_displayname_attribute self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c1d41421c..86bbcc0eea 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet): set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet): def on_PUT_no_state_key(self, request, room_id, event_type): return self.on_PUT(request, room_id, event_type, "") - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_type, state_key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] ) msg_handler = self.message_handler - data = yield msg_handler.get_room_data( + data = await msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, event_type=event_type, @@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) if txn_id: set_tag("txn_id", txn_id) @@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.room_member_handler.update_membership( + event = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict = { @@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_identifier, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet): elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): server = parse_string(request, "server", default=None) try: - yield self.auth.get_user_by_req(request, allow_guest=True) + await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token ) return 200, data - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) @@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token, @@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): third_party_instance_id=third_party_instance_id, ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, @@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): + async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet): membership = parse_string(request, "membership") not_membership = parse_string(request, "not_membership") - events = yield handler.get_state_events( + events = await handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, @@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - users_with_profile = yield self.message_handler.get_joined_members( + users_with_profile = await self.message_handler.get_joined_members( requester, room_id ) @@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) @@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet): as_client_event = False else: event_filter = None - msgs = yield self.pagination_handler.get_messages( + msgs = await self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room - events = yield self.message_handler.get_state_events( + events = await self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.initial_sync_handler.room_initial_sync( + content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) return 200, content @@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: - event = yield self.event_handler.get_event( + event = await self.event_handler.get_event( requester.user, room_id, event_id ) except AuthError: @@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = yield self.room_context_handler.get_event_context( + results = await self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, event_filter ) @@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"] = await self._event_serializer.serialize_events( results["events_before"], time_now ) - results["event"] = yield self._event_serializer.serialize_event( + results["event"] = await self._event_serializer.serialize_event( results["event"], time_now ) - results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"] = await self._event_serializer.serialize_events( results["events_after"], time_now ) - results["state"] = yield self._event_serializer.serialize_events( + results["state"] = await self._event_serializer.serialize_events( results["state"], time_now ) @@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget(user=requester.user, room_id=room_id) + await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} @@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, @@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.room_member_handler.do_3pid_invite( + await self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content and membership_action in ["kick", "ban"]: event_content = {"reason": content["reason"]} - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) if content["typing"]: - yield self.typing_handler.started_typing( + await self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=timeout, ) else: - yield self.typing_handler.stopped_typing( + await self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id ) @@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = yield self.handlers.search_handler.search( + results = await self.handlers.search_handler.search( requester.user, content, batch ) @@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 80cf7126a0..f26eae794c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User password resets have been disabled due to lack of email config" ) raise SynapseError( @@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_password_reset_template_failure_html], ) @@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Password reset emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_add_threepid_template_failure_html], ) @@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index b3bf8567e1..67cbc37312 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) read_event_id = body.get("m.read", None) if read_event_id: - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), @@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet): read_marker_event_id = body.get("m.fully_read", None) if read_marker_event_id: - yield self.read_marker_handler.received_client_read_marker( + await self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), event_id=read_marker_event_id, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 0dab03d227..92555bd4a9 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet @@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id, receipt_type, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 4f24a124a6..91db923814 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Email registration has been disabled due to lack of email config" ) raise SynapseError( @@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) @@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User registration via email has been disabled due to lack of email config" ) raise SynapseError( @@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet): # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params - logger.warn("Ignoring initial_device_display_name without password") + logger.warning("Ignoring initial_device_display_name without password") del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 541a6b0e10..ccd8b17b23 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -394,7 +394,7 @@ class SyncRestServlet(RestServlet): # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: - logger.warn( + logger.warning( "Event %r is under room %r instead of %r", event.event_id, room.room_id, diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 55580bc59e..e7fc3f0431 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource): @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: - server, = request.postpath + (server,) = request.postpath query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index b972e152a9..bd9186fe50 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -363,7 +363,7 @@ class MediaRepository(object): }, ) except RequestSendFailed as e: - logger.warn( + logger.warning( "Request failed fetching remote media %s/%s: %r", server_name, media_id, @@ -372,7 +372,7 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn( + logger.warning( "HTTP error fetching remote media %s/%s: %s", server_name, media_id, @@ -383,10 +383,12 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) + logger.warning("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: logger.exception( @@ -691,7 +693,7 @@ class MediaRepository(object): try: os.remove(full_path) except OSError as e: - logger.warn("Failed to remove file: %r", full_path) + logger.warning("Failed to remove file: %r", full_path) if e.errno == errno.ENOENT: pass else: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 094ebad770..531d923f76 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path @@ -136,7 +138,7 @@ class PreviewUrlResource(DirectServeResource): match = False continue if match: - logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) + logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) @@ -208,7 +210,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % url) + logger.warning("Couldn't get dims for %s" % url) # define our OG response for this media elif _is_html(media_info["media_type"]): @@ -256,7 +258,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s", og["og:image"]) + logger.warning("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -267,7 +269,7 @@ class PreviewUrlResource(DirectServeResource): else: del og["og:image"] else: - logger.warn("Failed to find any OG data in %s", url) + logger.warning("Failed to find any OG data in %s", url) og = {} logger.debug("Calculated OG for %s as %s", url, og) @@ -319,7 +321,7 @@ class PreviewUrlResource(DirectServeResource): ) except Exception as e: # FIXME: pass through 404s and other error messages nicely - logger.warn("Error downloading %s: %r", url, e) + logger.warning("Error downloading %s: %r", url, e) raise SynapseError( 500, @@ -400,7 +402,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -432,7 +434,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue try: @@ -448,7 +450,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 08329884ac..931ce79be8 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks @@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks diff --git a/synapse/server.py b/synapse/server.py index 1fcc7375d3..f8aeebcff8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,6 +23,7 @@ # Imports required for the default HomeServer() implementation import abc import logging +import os from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail @@ -95,6 +96,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler +from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -167,6 +169,7 @@ class HomeServer(object): "filtering", "http_client_context_factory", "simple_http_client", + "proxied_http_client", "media_repository", "media_repository_resource", "federation_transport_client", @@ -196,6 +199,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "storage", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -224,7 +228,7 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.datastore = None + self.datastores = None # Other kwargs are explicit dependencies for depname in kwargs: @@ -233,7 +237,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - self.datastore = self.DATASTORE_CLASS(conn, self) + datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(datastore, conn, self) conn.commit() logger.info("Finished setting up.") @@ -266,7 +271,7 @@ class HomeServer(object): return self.clock def get_datastore(self): - return self.datastore + return self.datastores.main def get_config(self): return self.config @@ -308,6 +313,13 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) + def build_proxied_http_client(self): + return SimpleHttpClient( + self, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), + ) + def build_room_creation_handler(self): return RoomCreationHandler(self) @@ -537,6 +549,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_storage(self) -> Storage: + return Storage(self, self.datastores) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 16f8f6b573..b5e0b57095 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -12,6 +12,7 @@ import synapse.handlers.message import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password +import synapse.http.client import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -38,8 +39,16 @@ class HomeServer(object): pass def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass + def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + or support any HTTP_PROXY settings""" + pass + def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + but does support HTTP_PROXY settings""" + pass def get_deactivate_account_handler( - self + self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: @@ -47,32 +56,32 @@ class HomeServer(object): def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass def get_event_creation_handler( - self + self, ) -> synapse.handlers.message.EventCreationHandler: pass def get_set_password_handler( - self + self, ) -> synapse.handlers.set_password.SetPasswordHandler: pass def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass def get_federation_transport_client( - self + self, ) -> synapse.federation.transport.client.TransportLayerClient: pass def get_media_repository_resource( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass def get_media_repository( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass def get_server_notices_manager( - self + self, ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass def get_server_notices_sender( - self + self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index c0e7f475c9..9fae2e0afe 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -83,7 +83,7 @@ class ResourceLimitsServerNotices(object): room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) if not room_id: - logger.warn("Failed to get server notices room") + logger.warning("Failed to get server notices room") return yield self._check_and_set_tags(user_id, room_id) diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py new file mode 100644 index 0000000000..efcc10f808 --- /dev/null +++ b/synapse/spam_checker_api/__init__.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +class SpamCheckerApi(object): + """A proxy object that gets passed to spam checkers so they can get + access to rooms and other relevant information. + """ + + def __init__(self, hs): + self.hs = hs + + self._store = hs.get_datastore() + + @defer.inlineCallbacks + def get_state_events_in_room(self, room_id, types): + """Gets state events for the given room. + + Args: + room_id (string): The room ID to get state events in. + types (tuple): The event type and state key (using None + to represent 'any') of the room state to acquire. + + Returns: + twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: + The filtered state events in the room. + """ + state_ids = yield self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) + state = yield self._store.get_events(state_ids.values()) + return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dc9f5a9008..4e91eb66fe 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -103,6 +103,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -271,7 +272,7 @@ class StateHandler(object): else: current_state_ids = prev_state_ids - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=None, @@ -321,7 +322,7 @@ class StateHandler(object): delta_ids = dict(entry.delta_ids) delta_ids[key] = event.event_id - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -334,7 +335,7 @@ class StateHandler(object): delta_ids = entry.delta_ids if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( + entry.state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=entry.prev_group, @@ -376,14 +377,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a249ecd219..0a1a8cc1e5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -27,7 +27,26 @@ data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from synapse.storage.data_stores.main import DataStore # noqa: F401 +from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main import DataStore +from synapse.storage.persist_events import EventsPersistenceStorage +from synapse.storage.state import StateGroupStorage + +__all__ = ["DataStores", "DataStore"] + + +class Storage(object): + """The high level interfaces for talking to various storage layers. + """ + + def __init__(self, hs, stores: DataStores): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.persistence = EventsPersistenceStorage(hs, stores) + self.state = StateGroupStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f5906fcd54..1a2b7ebe25 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -494,7 +494,7 @@ class SQLBaseStore(object): exception_callbacks = [] if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn("Starting db txn '%s' from sentinel context", desc) + logger.warning("Starting db txn '%s' from sentinel context", desc) try: result = yield self.runWithConnection( @@ -532,7 +532,7 @@ class SQLBaseStore(object): """ parent_context = LoggingContext.current_context() if parent_context == LoggingContext.sentinel: - logger.warn( + logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None @@ -719,7 +719,7 @@ class SQLBaseStore(object): raise # presumably we raced with another transaction: let's retry. - logger.warn( + logger.warning( "IntegrityError when upserting into %s; retrying: %s", table, e ) diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 56094078ed..cb184a98cc 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -12,3 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +class DataStores(object): + """The various data stores. + + These are low level interfaces to physical databases. + """ + + def __init__(self, main_store, db_conn, hs): + # Note we pass in the main store here as workers use a different main + # store. + self.main = main_store diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index b185ba0b3e..10c940df1e 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -139,7 +139,10 @@ class DataStore( db_conn, "public_room_list_stream", "stream_id" ) self._device_list_id_gen = StreamIdGenerator( - db_conn, "device_lists_stream", "stream_id" + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[("user_signature_stream", "stream_id")], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" @@ -317,7 +320,7 @@ class DataStore( ) u """ txn.execute(sql, (time_from,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count def count_r30_users(self): @@ -396,7 +399,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - count, = txn.fetchone() + (count,) = txn.fetchone() results["all"] = count return results diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index f7a3542348..71f62036c0 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -37,6 +37,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList @@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore): @trace @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit): - """Get stream of updates to send to remote servers + def get_device_updates_by_remote(self, destination, from_stream_id, limit): + """Get a stream of device updates to send to the given remote server. + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + limit (int): Maximum number of device updates to return Returns: - Deferred[tuple[int, list[dict]]]: + Deferred[tuple[int, list[tuple[string,dict]]]]: current stream id (ie, the stream id of the last update included in the - response), and the list of updates + response), and the list of updates, where each update is a pair of EDU + type and EDU contents """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore): # stream_id; the rationale being that such a large device list update # is likely an error. updates = yield self.runInteraction( - "get_devices_by_remote", - self._get_devices_by_remote_txn, + "get_device_updates_by_remote", + self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, @@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore): if not updates: return now_stream_id, [] + # get the cross-signing keys of the users in the list, so that we can + # determine which of the device changes were cross-signing keys + users = set(r[0] for r in updates) + master_key_by_user = {} + self_signing_key_by_user = {} + for user in users: + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + master_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + + cross_signing_key = yield self.get_e2e_cross_signing_key( + user, "self_signing" + ) + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. if len(updates) > limit: @@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore): # context which created the Edu. query_map = {} - for update in updates: - if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: + cross_signing_keys_by_user = {} + for user_id, device_id, update_stream_id, update_context in updates: + if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff: # Stop processing updates break - key = (update[0], update[1]) - - update_context = update[3] - update_stream_id = update[2] + if ( + user_id in master_key_by_user + and device_id == master_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = master_key_by_user[user_id]["key_info"] + elif ( + user_id in self_signing_key_by_user + and device_id == self_signing_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["self_signing_key"] = self_signing_key_by_user[user_id][ + "key_info" + ] + else: + key = (user_id, device_id) - previous_update_stream_id, _ = query_map.get(key, (0, None)) + previous_update_stream_id, _ = query_map.get(key, (0, None)) - if update_stream_id > previous_update_stream_id: - query_map[key] = (update_stream_id, update_context) + if update_stream_id > previous_update_stream_id: + query_map[key] = (update_stream_id, update_context) # If we didn't find any updates with a stream_id lower than the cutoff, it # means that there are more than limit updates all of which have the same @@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore): # devices, in which case E2E isn't going to work well anyway. We'll just # skip that stream_id and return an empty list, and continue with the next # stream_id next time. - if not query_map: + if not query_map and not cross_signing_keys_by_user: return stream_id_cutoff, [] results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) + # add the updated cross-signing keys to the results list + for user_id, result in iteritems(cross_signing_keys_by_user): + result["user_id"] = user_id + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("org.matrix.signing_key_update", result)) + return now_stream_id, results - def _get_devices_by_remote_txn( + def _get_device_updates_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): """Return device update information for a given remote destination @@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: List: List of device updates """ + # get the list of device updates that need to be sent sql = """ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? @@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore): List[Dict]: List of objects representing an device update EDU """ - devices = yield self.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, + devices = ( + yield self.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, + ) + if query_map + else {} ) results = [] @@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore): else: result["deleted"] = True - results.append(result) + results.append(("m.device_list_update", result)) return results diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index a0bc6f2d18..073412a78d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): + """Return a list of changes from the user signature stream to notify remotes. + Note that the user signature stream represents when a user signs their + device with their user-signing key, which is not published to other + users or servers, so no `destination` is needed in the returned + list. However, this is needed to poke workers. + + Args: + from_key (int): the stream ID to start at (exclusive) + to_key (int): the stream ID to end at (inclusive) + + Returns: + Deferred[list[(int,str)]] a list of `(stream_id, user_id)` + """ + sql = """ + SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id + """ + return self._execute( + "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 22025effbc..04ce21ac66 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) stream_row = txn.fetchone() if stream_row: - offset_stream_ordering, = stream_row + (offset_stream_ordering,) = stream_row rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 1045c7fa2e..301f8ea128 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,14 +17,14 @@ import itertools import logging -from collections import Counter as c_counter, OrderedDict, deque, namedtuple +from collections import Counter as c_counter, OrderedDict, namedtuple from functools import wraps from six import iteritems, text_type from six.moves import range from canonicaljson import json -from prometheus_client import Counter, Histogram +from prometheus_client import Counter from twisted.internet import defer @@ -34,11 +34,9 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event_dict -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore @@ -46,10 +44,8 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -60,37 +56,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -# The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") - -# The number of times we are recalculating state when there is only a -# single forward extremity -state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" -) - -# The number of times we are reculating state when we could have resonably -# calculated the delta when we calculated the state for an event we were -# persisting. -state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" -) - -# The number of forward extremities for each new event. -forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", - "Number of forward extremities for each new event", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -# The number of stale forward extremities for each new event. Stale extremities -# are those that were in the previous set of extremities as well as the new. -stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", - "Number of unchanged forward extremities for each new event", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - def encode_json(json_object): """ @@ -102,110 +67,6 @@ def encode_json(json_object): return out -class _EventPeristenceQueue(object): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room. - """ - - _EventPersistQueueItem = namedtuple( - "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") - ) - - def __init__(self): - self._event_persist_queues = {} - self._currently_persisting_rooms = set() - - def add_to_queue(self, room_id, events_and_contexts, backfilled): - """Add events to the queue, with the given persist_event options. - - NB: due to the normal usage pattern of this method, it does *not* - follow the synapse logcontext rules, and leaves the logcontext in - place whether or not the returned deferred is ready. - - Args: - room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - - Returns: - defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. - """ - queue = self._event_persist_queues.setdefault(room_id, deque()) - if queue: - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - end_item = queue[-1] - if end_item.backfilled == backfilled: - end_item.events_and_contexts.extend(events_and_contexts) - return end_item.deferred.observe() - - deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - - queue.append( - self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - ) - ) - - return deferred.observe() - - def handle_queue(self, room_id, per_item_callback): - """Attempts to handle the queue for a room if not already being handled. - - The given callback will be invoked with for each item in the queue, - of type _EventPersistQueueItem. The per_item_callback will continuously - be called with new items, unless the queue becomnes empty. The return - value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferreds as well. - - This function should therefore be called whenever anything is added - to the queue. - - If another callback is currently handling the queue then it will not be - invoked. - """ - - if room_id in self._currently_persisting_rooms: - return - - self._currently_persisting_rooms.add(room_id) - - @defer.inlineCallbacks - def handle_queue_loop(): - try: - queue = self._get_drainining_queue(room_id) - for item in queue: - try: - ret = yield per_item_callback(item) - except Exception: - with PreserveLoggingContext(): - item.deferred.errback() - else: - with PreserveLoggingContext(): - item.deferred.callback(ret) - finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue - self._currently_persisting_rooms.discard(room_id) - - # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) - - def _get_drainining_queue(self, room_id): - queue = self._event_persist_queues.setdefault(room_id, deque()) - - try: - while True: - yield queue.popleft() - except IndexError: - # Queue has been drained. - pass - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -221,7 +82,7 @@ def _retry_on_integrity_error(func): @defer.inlineCallbacks def f(self, *args, **kwargs): try: - res = yield func(self, *args, **kwargs) + res = yield func(self, *args, delete_existing=False, **kwargs) except self.database_engine.module.IntegrityError: logger.exception("IntegrityError, retrying.") res = yield func(self, *args, delete_existing=True, **kwargs) @@ -241,9 +102,6 @@ class EventsStore( def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() - # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count self._current_forward_extremities_amount = c_counter() @@ -286,340 +144,106 @@ class EventsStore( res = yield self.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) - @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False): - """ - Write events to the database - Args: - events_and_contexts: list of tuples of (event, context) - backfilled (bool): Whether the results are retrieved from federation - via backfill or not. Used to determine if they're "new" events - which might update the current state etc. - - Returns: - Deferred[int]: the stream ordering of the latest persisted event - """ - partitioned = {} - for event, ctx in events_and_contexts: - partitioned.setdefault(event.room_id, []).append((event, ctx)) - - deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): - d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled - ) - deferreds.append(d) - - for room_id in partitioned: - self._maybe_start_persisting(room_id) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - - return max_persisted_id - - @defer.inlineCallbacks - @log_function - def persist_event(self, event, context, backfilled=False): - """ - - Args: - event (EventBase): - context (EventContext): - backfilled (bool): - - Returns: - Deferred: resolves to (int, int): the stream ordering of ``event``, - and the stream ordering of the latest persisted event - """ - deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - - self._maybe_start_persisting(event.room_id) - - yield make_deferred_yieldable(deferred) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) - - def _maybe_start_persisting(self, room_id): - @defer.inlineCallbacks - def persisting_queue(item): - with Measure(self._clock, "persist_events"): - yield self._persist_events( - item.events_and_contexts, backfilled=item.backfilled - ) - - self._event_persist_queue.handle_queue(room_id, persisting_queue) - @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events( - self, events_and_contexts, backfilled=False, delete_existing=False + def _persist_events_and_state_updates( + self, + events_and_contexts, + current_state_for_room, + state_delta_for_room, + new_forward_extremeties, + backfilled=False, + delete_existing=False, ): - """Persist events to db + """Persist a set of events alongside updates to the current state and + forward extremities tables. Args: events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): + current_state_for_room (dict[str, dict]): Map from room_id to the + current state of the room based on forward extremities + state_delta_for_room (dict[str, tuple]): Map from room_id to tuple + of `(to_delete, to_insert)` where to_delete is a list + of type/state keys to remove from current state, and to_insert + is a map (type,key)->event_id giving the state delta in each + room. + new_forward_extremities (dict[str, list[str]]): Map from room_id + to list of event IDs that are the new forward extremities of + the room. + backfilled (bool) delete_existing (bool): Returns: Deferred: resolves when the events have been persisted """ - if not events_and_contexts: - return - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) + ) - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + delete_existing=delete_existing, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + ) + persist_event_counter.inc(len(events_and_contexts)) if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - - # We want to calculate the stream orderings as late as possible, as - # we only notify after all events with a lesser stream ordering have - # been persisted. I.e. if we spend 10s inside the with block then - # that will delay all subsequent events from being notified about. - # Hence why we do it down here rather than wrapping the entire - # function. - # - # Its safe to do this after calculating the state deltas etc as we - # only need to protect the *persistence* of the events. This is to - # ensure that queries of the form "fetch events since X" don't - # return events and stream positions after events that are still in - # flight, as otherwise subsequent requests "fetch event since Y" - # will not return those events. - # - # Note: Multiple instances of this function cannot be in flight at - # the same time for the same room. - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(chunk) + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + events_and_contexts[-1][0].internal_metadata.stream_ordering ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(chunk, stream_orderings): - event.internal_metadata.stream_ordering = stream - - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, - ) - persist_event_counter.inc(len(chunk)) - - if not backfilled: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering - ) - for event, context in chunk: - if context.app_service: - origin_type = "local" - origin_entity = context.app_service.id - elif self.hs.is_mine_id(event.sender): - origin_type = "local" - origin_entity = "*client*" - else: - origin_type = "remote" - origin_entity = get_domain_from_id(event.sender) - - event_counter.labels(event.type, origin_type, origin_entity).inc() - - for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) - - for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) - ) - - @defer.inlineCallbacks - def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremities for a room given events to - persist. - - Assumes that we are only persisting events for one room at a time. - """ - - # we're only interested in new events which aren't outliers and which aren't - # being rejected. - new_events = [ - event - for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() - and not ctx.rejected - and not event.internal_metadata.is_soft_failed() - ] - - latest_event_ids = set(latest_event_ids) - - # start with the existing forward extremities - result = set(latest_event_ids) - - # add all the new events to the list - result.update(event.event_id for event in new_events) - - # Now remove all events which are prev_events of any of the new events - result.difference_update( - e_id for event in new_events for e_id in event.prev_event_ids() - ) + for event, context in events_and_contexts: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) - # Remove any events which are prev_events of any existing events. - existing_prevs = yield self._get_events_which_are_prevs(result) - result.difference_update(existing_prevs) + event_counter.labels(event.type, origin_type, origin_entity).inc() - # Finally handle the case where the new events have soft-failed prev - # events. If they do we need to remove them and their prev events, - # otherwise we end up with dangling extremities. - existing_prevs = yield self._get_prevs_before_rejected( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - result.difference_update(existing_prevs) + for room_id, new_state in iteritems(current_state_for_room): + self.get_current_state_ids.prefill((room_id,), new_state) - # We only update metrics for events that change forward extremities - # (e.g. we ignore backfill/outliers/etc) - if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) - stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) - - return result + for room_id, latest_event_ids in iteritems(new_forward_extremeties): + self.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) @defer.inlineCallbacks def _get_events_which_are_prevs(self, event_ids): @@ -725,188 +349,6 @@ class EventsStore( return existing_prevs - @defer.inlineCallbacks - def _get_new_state_after_events( - self, room_id, events_context, old_latest_event_ids, new_latest_event_ids - ): - """Calculate the current state dict after adding some new events to - a room - - Args: - room_id (str): - room to which the events are being added. Used for logging etc - - events_context (list[(EventBase, EventContext)]): - events and contexts which are being added to the room - - old_latest_event_ids (iterable[str]): - the old forward extremities for the room. - - new_latest_event_ids (iterable[str]): - the new forward extremities for the room. - - Returns: - Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. - - If there has been a change then we only return the delta if its - already been calculated. Conversely if we do know the delta then - the new current state is only returned if we've already calculated - it. - """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - - # Map from (prev state group, new state group) -> delta state dict - state_group_deltas = {} - - for ev, ctx in events_context: - if ctx.state_group is None: - # This should only happen for outlier events. - if not ev.internal_metadata.is_outlier(): - raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) - ) - continue - - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids - - # We need to map the event_ids to their state groups. First, let's - # check if the event is one we're persisting, in which case we can - # pull the state group from its context. - # Otherwise we need to pull the state group from the database. - - # Set of events we need to fetch groups for. (We know none of the old - # extremities are going to be in events_context). - missing_event_ids = set(old_latest_event_ids) - - event_id_to_state_group = {} - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding. - for ev, ctx in events_context: - if event_id == ev.event_id and ctx.state_group is not None: - event_id_to_state_group[event_id] = ctx.state_group - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - missing_event_ids.add(event_id) - - if missing_event_ids: - # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events(missing_event_ids) - event_id_to_state_group.update(event_to_groups) - - # State groups of old_latest_event_ids - old_state_groups = set( - event_id_to_state_group[evid] for evid in old_latest_event_ids - ) - - # State groups of new_latest_event_ids - new_state_groups = set( - event_id_to_state_group[evid] for evid in new_latest_event_ids - ) - - # If they old and new groups are the same then we don't need to do - # anything. - if old_state_groups == new_state_groups: - return None, None - - if len(new_state_groups) == 1 and len(old_state_groups) == 1: - # If we're going from one state group to another, lets check if - # we have a delta for that transition. If we do then we can just - # return that. - - new_state_group = next(iter(new_state_groups)) - old_state_group = next(iter(old_state_groups)) - - delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) - if delta_ids is not None: - # We have a delta from the existing to new current state, - # so lets just return that. If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids - - # Now that we have calculated new_state_groups we need to get - # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = yield self._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) - - if len(new_state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_groups_map[new_state_groups.pop()], None - - # Ok, we need to defer to the state handler to resolve our state sets. - - state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} - - events_map = {ev.event_id: ev for ev, _ in events_context} - - # We need to get the room version, which is in the create event. - # Normally that'd be in the database, but its also possible that we're - # currently trying to persist it. - room_version = None - for ev, _ in events_context: - if ev.type == EventTypes.Create and ev.state_key == "": - room_version = ev.content.get("room_version", "1") - break - - if not room_version: - room_version = yield self.get_room_version(room_id) - - logger.debug("calling resolve_state_groups from preserve_events") - res = yield self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_groups, - events_map, - state_res_store=StateResolutionStore(self), - ) - - return res.state, None - - @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, current_state): - """Calculate the new state deltas for a room. - - Assumes that we are only persisting events for one room at a time. - - Returns: - tuple[list, dict] (to_delete, to_insert): where to_delete are the - type/state_keys to remove from current_state_events and `to_insert` - are the updates to current_state_events. - """ - existing_state = yield self.get_current_state_ids(room_id) - - to_delete = [key for key in existing_state if key not in current_state] - - to_insert = { - key: ev_id - for key, ev_id in iteritems(current_state) - if ev_id != existing_state.get(key) - } - - return to_delete, to_insert - @log_function def _persist_events_txn( self, @@ -1690,7 +1132,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_messages", _count_messages) @@ -1711,7 +1153,7 @@ class EventsStore( """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) @@ -1726,7 +1168,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_active_rooms", _count) @@ -2211,7 +1653,7 @@ class EventsStore( """, (room_id,), ) - min_depth, = txn.fetchone() + (min_depth,) = txn.fetchone() logger.info("[purge] updating room_depth to %d", min_depth) @@ -2403,7 +1845,6 @@ class EventsStore( "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", - "topics", "users_in_public_rooms", "users_who_share_private_rooms", # no useful index, but let's clear them anyway @@ -2446,12 +1887,11 @@ class EventsStore( logger.info("[purge] done") - @defer.inlineCallbacks - def is_event_after(self, event_id1, event_id2): + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = yield self._get_event_ordering(event_id1) - to_2, so_2 = yield self._get_event_ordering(event_id2) + to_1, so_1 = await self._get_event_ordering(event_id1) + to_2, so_2 = await self._get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 31ea6f917f..51352b9966 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if not rows: return 0 - upper_event_id, = rows[-1] + (upper_event_id,) = rows[-1] # Update the redactions with the received_ts. # diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index aeae5a2b28..b3a2771f1b 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index e6ee1e4aaa..b41c3d317a 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" txn.execute(sql) - count, = txn.fetchone() + (count,) = txn.fetchone() return count return self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index cd95f1ce60..b520062d84 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -143,7 +143,7 @@ class PushRulesWorkerStore( " WHERE user_id = ? AND ? < stream_id" ) txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() + (count,) = txn.fetchone() return bool(count) return self.runInteraction( diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f005c1ae0a..d76861cdc0 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore): r["data"] = json.loads(dataJson) except Exception as e: - logger.warn( + logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], dataJson, diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 6c5b29288a..f70d41ecab 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE appservice_id IS NULL """ ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index bc04bfd7d4..2af24a20b7 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): if not row or not row[0]: return processed, True - next_room, = row + (next_room,) = row sql = """ UPDATE current_state_events diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite new file mode 100644 index 0000000000..e8b1fd35d8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite @@ -0,0 +1,42 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Change the hidden column from a default value of FALSE to a default value of + * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the + * string 'FALSE', which is truthy. + * + * Since sqlite doesn't allow us to just change the default value, we have to + * recreate the table, copy the data, fix the rows that have incorrect data, and + * replace the old table with the new table. + */ + +CREATE TABLE IF NOT EXISTS devices2 ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + last_seen BIGINT, + ip TEXT, + user_agent TEXT, + hidden BOOLEAN DEFAULT 0, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); + +INSERT INTO devices2 SELECT * FROM devices; + +UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE'; + +DROP TABLE devices; + +ALTER TABLE devices2 RENAME TO devices; diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 0e08497452..d1d7c6863d 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): " ON event_search USING GIN (vector)" ) except psycopg2.ProgrammingError as e: - logger.warn( + logger.warning( "Ignoring error %r when trying to switch from GIST to GIN", e ) @@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore): ) ) txn.execute(query, (value, search_query)) - headline, = txn.fetchall()[0] + (headline,) = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline # result. diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9b2207075b..3132848034 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -725,16 +725,18 @@ class StateGroupWorkerStore( member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - non_member_state, incomplete_groups_nm, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter ) - member_state, incomplete_groups_m, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter ) state = dict(non_member_state) @@ -1076,7 +1078,7 @@ class StateBackgroundUpdateStore( " WHERE id < ? AND room_id = ?", (state_group, room_id), ) - prev_group, = txn.fetchone() + (prev_group,) = txn.fetchone() new_last_state_group = state_group if prev_group: diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 4d59b7833f..45b3de7d56 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore): (room_id,), ) - current_state_events_count, = txn.fetchone() + (current_state_events_count,) = txn.fetchone() users_in_room = self.get_users_in_room_txn(txn, room_id) @@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count, pos joined_rooms, pos = yield self.runInteraction( diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py new file mode 100644 index 0000000000..fa03ca9ff7 --- /dev/null +++ b/synapse/storage/persist_events.py @@ -0,0 +1,649 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import deque, namedtuple + +from six import iteritems +from six.moves import range + +from prometheus_client import Counter, Histogram + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore +from synapse.storage.data_stores import DataStores +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") + +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "" +) + +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "" +) + +# The number of forward extremities for each new event. +forward_extremities_counter = Histogram( + "synapse_storage_events_forward_extremities_persisted", + "Number of forward extremities for each new event", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +# The number of stale forward extremities for each new event. Stale extremities +# are those that were in the previous set of extremities as well as the new. +stale_forward_extremities_counter = Histogram( + "synapse_storage_events_stale_forward_extremities_persisted", + "Number of unchanged forward extremities for each new event", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + + +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled): + """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. + end_item = queue[-1] + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) + + return deferred.observe() + + def handle_queue(self, room_id, per_item_callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with for each item in the queue, + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomnes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferreds as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + @defer.inlineCallbacks + def handle_queue_loop(): + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + ret = yield per_item_callback(item) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: + with PreserveLoggingContext(): + item.deferred.callback(ret) + finally: + queue = self._event_persist_queues.pop(room_id, None) + if queue: + self._event_persist_queues[room_id] = queue + self._currently_persisting_rooms.discard(room_id) + + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +class EventsPersistenceStorage(object): + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + + def __init__(self, hs, stores: DataStores): + # We ultimately want to split out the state store from the main store, + # so we use separate variables here even though they point to the same + # store for now. + self.main_store = stores.main + self.state_store = stores.main + + self._clock = hs.get_clock() + self.is_mine_id = hs.is_mine_id + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + + @defer.inlineCallbacks + def persist_events(self, events_and_contexts, backfilled=False): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled (bool): Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in iteritems(partitioned): + d = self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, backfilled=backfilled + ) + deferreds.append(d) + + for room_id in partitioned: + self._maybe_start_persisting(room_id) + + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) + + max_persisted_id = yield self.main_store.get_current_events_token() + + return max_persisted_id + + @defer.inlineCallbacks + def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) + + self._maybe_start_persisting(event.room_id) + + yield make_deferred_yieldable(deferred) + + max_persisted_id = yield self.main_store.get_current_events_token() + return (event.internal_metadata.stream_ordering, max_persisted_id) + + def _maybe_start_persisting(self, room_id): + @defer.inlineCallbacks + def persisting_queue(item): + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, backfilled=item.backfilled + ) + + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + @defer.inlineCallbacks + def _persist_events(self, events_and_contexts, backfilled=False): + """Calculates the change to current state and forward extremities, and + persists the given events and with those updates. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ + if not events_and_contexts: + return + + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = yield self.main_store.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, "persist_events.calculate_state_delta" + ): + delta = yield self._calculate_state_delta( + room_id, current_state + ) + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state + + yield self.main_store._persist_events_and_state_updates( + chunk, + current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + backfilled=backfilled, + ) + + @defer.inlineCallbacks + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected + and not event.internal_metadata.is_soft_failed() + ] + + latest_event_ids = set(latest_event_ids) + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update(event.event_id for event in new_events) + + # Now remove all events which are prev_events of any of the new events + result.difference_update( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + + # Remove any events which are prev_events of any existing events. + existing_prevs = yield self.main_store._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) + + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self.main_store._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + # (e.g. we ignore backfill/outliers/etc) + if result != latest_event_ids: + forward_extremities_counter.observe(len(result)) + stale = latest_event_ids & result + stale_forward_extremities_counter.observe(len(stale)) + + return result + + @defer.inlineCallbacks + def _get_new_state_after_events( + self, room_id, events_context, old_latest_event_ids, new_latest_event_ids + ): + """Calculate the current state dict after adding some new events to + a room + + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. + + Returns: + Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: + Returns a tuple of two state maps, the first being the full new current + state and the second being the delta to the existing current state. + If both are None then there has been no change. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. + """ + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id,) + ) + continue + + if ctx.state_group in state_groups_map: + continue + + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding. + for ev, ctx in events_context: + if event_id == ev.event_id and ctx.state_group is not None: + event_id_to_state_group[event_id] = ctx.state_group + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + missing_event_ids.add(event_id) + + if missing_event_ids: + # Now pull out the state groups for any missing events from DB + event_to_groups = yield self.main_store._get_state_group_for_events( + missing_event_ids + ) + event_id_to_state_group.update(event_to_groups) + + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) + + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) + + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return None, None + + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. + + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) + + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. If we happen to already have + # the current state in memory then lets also return that, + # but it doesn't matter if we don't. + new_state = state_groups_map.get(new_state_group) + return new_state, delta_ids + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self.state_store._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_groups_map[new_state_groups.pop()], None + + # Ok, we need to defer to the state handler to resolve our state sets. + + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} + + events_map = {ev.event_id: ev for ev, _ in events_context} + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = yield self.main_store.get_room_version(room_id) + + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self.main_store), + ) + + return res.state, None + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + tuple[list, dict] (to_delete, to_insert): where to_delete are the + type/state_keys to remove from current_state_events and `to_insert` + are the updates to current_state_events. + """ + existing_state = yield self.main_store.get_current_state_ids(room_id) + + to_delete = [key for key in existing_state if key not in current_state] + + to_insert = { + key: ev_id + for key, ev_id in iteritems(current_state) + if ev_id != existing_state.get(key) + } + + return to_delete, to_insert diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a2df8fa827..3735846899 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from six import iteritems, itervalues import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -322,3 +324,234 @@ class StateFilter(object): ) return member_filter, non_member_filter + + +class StateGroupStorage(object): + """High level interface to fetching state for event. + """ + + def __init__(self, hs, stores): + self.stores = stores + + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): + (prev_group, delta_ids) + """ + + return self.stores.main.get_state_group_delta(state_group) + + @defer.inlineCallbacks + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + if not event_ids: + return {} + + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups(groups) + + return group_to_state + + @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the event IDs of all the state in the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + return group_to_state[state_group] + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. + """ + if not event_ids: + return {} + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.stores.main.get_events( + [ + ev_id + for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + return self.stores.main._get_state_groups_from_groups(groups, state_filter) + + @defer.inlineCallbacks + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + Args: + event_ids (list[string]) + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + """ + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) + + state_event_map = yield self.stores.main.get_events( + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in iteritems(group_to_state[group]) + if v in state_event_map + } + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids(list(str)): events whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from event_id -> (type, state_key) -> event_id + """ + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_for_events([event_id], state_filter) + return state_map[event_id] + + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], state_filter) + return state_map[event_id] + + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + return self.stores.main._get_state_for_groups(groups, state_filter) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + return self.stores.main.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index cbb0a4810a..9d851beaa5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1): cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) - val, = cur.fetchone() + (val,) = cur.fetchone() cur.close() current_id = int(val) if val else step return (max if step > 0 else min)(current_id, step) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 804dbca443..5c4de2e69f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -86,11 +86,12 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underdlying deferred. """ if not self._result: d = defer.Deferred() @@ -105,7 +106,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers @@ -138,7 +139,7 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. + func (func): Function to execute, should return a deferred or coroutine. args (list): List of arguments to pass to func, each invocation of func gets a signle argument. limit (int): Maximum number of conccurent executions. @@ -148,11 +149,10 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass @@ -309,7 +309,7 @@ class Linearizer(object): ) else: - logger.warn( + logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", self.name, key, diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 43fd65d693..da5077b471 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None): if collect_callback: collect_callback() except Exception as e: - logger.warn("Error calculating metrics for %s: %s", cache_name, e) + logger.warning("Error calculating metrics for %s: %s", cache_name, e) raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b1bcdf23c..3286804322 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -119,7 +119,7 @@ class Measure(object): context = LoggingContext.current_context() if context != self.start_context: - logger.warn( + logger.warning( "Context has unexpectedly changed from '%s' to '%s'. (%r)", self.start_context, context, @@ -128,7 +128,7 @@ class Measure(object): return if not context: - logger.warn("Expected context. (%r)", self.name) + logger.warning("Expected context. (%r)", self.name) return current = context.get_resource_usage() @@ -140,7 +140,7 @@ class Measure(object): block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warn( + logger.warning( "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index 6c0f2bb0cf..207cd17c2a 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no): resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) ) except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) + logger.warning("Failed to set file or core limit: %s", e) diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index fa404b9d75..ab7d03af3a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -42,6 +42,7 @@ def get_version_string(module): try: null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: git_branch = ( subprocess.check_output( @@ -51,7 +52,8 @@ def get_version_string(module): .decode("ascii") ) git_branch = "b=" + git_branch - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed git_branch = "" try: @@ -63,7 +65,7 @@ def get_version_string(module): .decode("ascii") ) git_tag = "t=" + git_tag - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_tag = "" try: @@ -74,7 +76,7 @@ def get_version_string(module): .strip() .decode("ascii") ) - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_commit = "" try: @@ -89,7 +91,7 @@ def get_version_string(module): ) git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: diff --git a/synapse/visibility.py b/synapse/visibility.py index bf0f1eebd8..8c843febd8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage import Storage from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id @@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - store, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() ): """ Check which events a user is allowed to see Args: - store (synapse.storage.DataStore): our datastore (can also be a worker - store) + storage user_id(str): user id to be checked events(list[synapse.events.EventBase]): sequence of events to be checked is_peeking(bool): should be True if: @@ -68,12 +68,12 @@ def filter_events_for_client( events = list(e for e in events if not e.internal_metadata.is_soft_failed()) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield store.get_state_for_events( + event_id_to_state = yield storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -84,7 +84,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) def allowed(event): """ @@ -213,13 +213,17 @@ def filter_events_for_client( @defer.inlineCallbacks def filter_events_for_server( - store, server_name, events, redact=True, check_history_visibility_only=False + storage: Storage, + server_name, + events, + redact=True, + check_history_visibility_only=False, ): """Filter a list of events based on whether given server is allowed to see them. Args: - store (DataStore) + storage server_name (str) events (iterable[FrozenEvent]) redact (bool): Whether to return a redacted version of the event, or @@ -274,7 +278,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -292,14 +296,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield store.get_events(visibility_ids) + event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in itervalues(event_map) ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -328,7 +332,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -358,7 +362,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield store.get_events( + event_map = yield storage.main.get_events( [ e_id for e_id, key in iteritems(event_id_to_state_key) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index c4f0bbd3dd..8efd39c7f7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 67f1013051..5ec568f4e6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", - "get_devices_by_remote", + "get_device_updates_by_remote", # Bits that user_directory needs "get_user_directory_stream_pos", "get_current_state_deltas", @@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_devices_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = (0, []) def get_received_txn_response(*args): return defer.succeed(None) @@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2d5dba6464..2096ba3c91 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -20,6 +20,23 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection from twisted.internet.interfaces import IOpenSSLServerConnectionCreator +from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 +from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 + + +def get_test_https_policy(): + """Get a test IPolicyForHTTPS which trusts the test CA cert + + Returns: + IPolicyForHTTPS + """ + ca_file = get_test_ca_cert_file() + with open(ca_file) as stream: + content = stream.read() + cert = Certificate.loadPEM(content) + trust_root = trustRootFromCertificates([cert]) + return BrowserLikePolicyForHTTPS(trustRoot=trust_root) def get_test_ca_cert_file(): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 71d7025264..cfcd98ff7d 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase): FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) + # grab a hold of the TLS connection, in case it gets torn down + server_tls_connection = server_tls_protocol._tlsConnection + + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_tls_protocol.wrappedProtocol + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) # check the SNI - server_name = server_tls_protocol._tlsConnection.get_servername() + server_name = server_tls_connection.get_servername() self.assertEqual( server_name, expected_sni, "Expected SNI %s but got %s" % (expected_sni, server_name), ) - # fish the test server back out of the server-side TLS protocol. - return server_tls_protocol.wrappedProtocol + return http_protocol @defer.inlineCallbacks def _make_get_request(self, uri): diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py new file mode 100644 index 0000000000..22abf76515 --- /dev/null +++ b/tests/http/test_proxyagent.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import treq + +from twisted.internet import interfaces # noqa: F401 +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel + +from synapse.http.proxyagent import ProxyAgent + +from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.unittest import TestCase + +logger = logging.getLogger(__name__) + +HTTPFactory = Factory.forProtocol(HTTPChannel) + + +class MatrixFederationAgentTests(TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def _make_connection( + self, client_factory, server_factory, ssl=False, expected_sni=None + ): + """Builds a test server, and completes the outgoing client connection + + Args: + client_factory (interfaces.IProtocolFactory): the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + server_factory (interfaces.IProtocolFactory): a factory to build the + server-side protocol + + ssl (bool): If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + + expected_sni (bytes|None): the expected SNI value + + Returns: + IProtocol: the server Protocol returned by server_factory + """ + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory) + + server_protocol = server_factory.buildProtocol(None) + + # now, tell the client protocol factory to build the client protocol, + # and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_protocol, self.reactor, client_protocol) + ) + + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) + ) + + if ssl: + http_protocol = server_protocol.wrappedProtocol + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None + + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) + + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + return http_protocol + + def test_http_request(self): + agent = ProxyAgent(self.reactor) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 80) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request(self): + agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + _get_test_protocol_factory(), + ssl=True, + expected_sni=b"test.com", + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_http_request_via_proxy(self): + agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888") + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy(self): + agent = ProxyAgent( + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + +def _wrap_server_factory_for_tls(factory, sanlist=None): + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` + + Args: + factory (interfaces.IProtocolFactory): protocol factory to wrap + sanlist (iterable[bytes]): list of domains the cert should be valid for + + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [b"DNS:test.com"] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + + +def _get_test_protocol_factory(): + """Get a protocol Factory which will build an HTTPChannel + + Returns: + interfaces.IProtocolFactory + """ + server_factory = Factory.forProtocol(HTTPChannel) + + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + return server_factory + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8ce6bb62da..af2327fb66 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, simple_http_client=m) + hs = self.setup_test_homeserver(config=config, proxied_http_client=m) return hs diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 104349cdbd..4f924ce451 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() + self.storage = hs.get_storage() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index a368117b43..b68e9fe082 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.get_success( + self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + ) self.replicate() event_source = RoomEventSource(self.hs) @@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) + self.storage.persistence.persist_events( + [(event, context)], backfilled=True + ) ) else: - self.get_success(self.master_store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index d3a4f717f7..8e1ca8b738 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -561,3 +561,81 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + ): + count = self.get_success( + self.store._simple_select_one_onecol( + table="events", + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) diff --git a/tests/server.py b/tests/server.py index e397ebe8fa..f878aeaada 100644 --- a/tests/server.py +++ b/tests/server.py @@ -161,7 +161,11 @@ def make_request( path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand - if shorthand and not path.startswith(b"/_matrix"): + if ( + shorthand + and not path.startswith(b"/_matrix") + and not path.startswith(b"/_synapse") + ): path = b"/_matrix/client/r0/" + path path = path.replace(b"//", b"/") @@ -391,11 +395,24 @@ class FakeTransport(object): self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) - self.disconnected = True + + # if we still have data to write, delay until that is done + if self.buffer: + logger.info( + "FakeTransport: Delaying disconnect until buffer is flushed" + ) + else: + self.disconnected = True def abortConnection(self): logger.info("FakeTransport: abortConnection()") - self.loseConnection() + + if not self.disconnecting: + self.disconnecting = True + if self._protocol: + self._protocol.connectionLost(None) + + self.disconnected = True def pauseProducing(self): if not self.producer: @@ -426,6 +443,9 @@ class FakeTransport(object): self._reactor.callLater(0.0, _produce) def write(self, byt): + if self.disconnecting: + raise Exception("Writing to disconnecting FakeTransport") + self.buffer = self.buffer + byt # always actually do the write asynchronously. Some protocols (notably the @@ -470,6 +490,10 @@ class FakeTransport(object): if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + if not self.buffer and self.disconnecting: + logger.info("FakeTransport: Buffer now empty, completing disconnect") + self.disconnected = True + def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: """ diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index dd49a14524..9b81b536f5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 3cc18f9f1c..6f8d990959 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks - def test_get_devices_by_remote(self): + def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id @@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "somehost", -1, limit=100 ) @@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self._check_devices_in_updates(device_ids, device_updates) @defer.inlineCallbacks - def test_get_devices_by_remote_limited(self): + def test_get_device_updates_by_remote_limited(self): # Test breaking the update limit in 1, 101, and 1 device_id segments # first add one device @@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # # first we should get a single update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) @@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) - received_device_ids = {update["device_id"] for update in device_updates} + received_device_ids = { + update["device_id"] for edu_type, update in device_updates + } self.assertEqual(received_device_ids, set(expected_device_ids)) @defer.inlineCallbacks diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 427d3c49c5..4561c3e383 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.store.persist_event(event_1, context_1)) + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.store.persist_event(event_2, context_2)) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1bee45706f..3ddaa151fe 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 447a3c6ffb..9ddd17f73d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5c2cf3c2db..43200654f1 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_datastore = self.store self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -63,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @@ -82,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) @@ -101,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -141,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -157,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -181,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -199,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -215,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -237,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -250,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -267,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -287,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -304,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -317,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -331,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -346,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) @@ -369,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -381,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -394,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -405,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -424,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -435,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -448,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -459,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/test_federation.py b/tests/test_federation.py index a73f18f88e..7d82b58466 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -58,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase): ) self.handler = self.homeserver.get_handlers().federation_handler - self.handler.do_auth = lambda *a, **b: succeed(True) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context + ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( pdus @@ -75,7 +78,8 @@ class MessageAcceptTests(unittest.TestCase): self.assertEqual( self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0], "$join:test.serv", @@ -97,7 +101,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -137,6 +142,6 @@ class MessageAcceptTests(unittest.TestCase): # Make sure the invalid event isn't there extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") diff --git a/tests/test_state.py b/tests/test_state.py index 610ec9fb46..38246555bd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -158,10 +158,12 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() + storage = Mock(main=self.store, state=self.store) hs = Mock( spec_set=[ "config", "get_datastore", + "get_storage", "get_auth", "get_state_handler", "get_clock", @@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) + hs.get_storage.return_value = storage self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 18f1a0035d..f7381b2885 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -14,6 +14,8 @@ # limitations under the License. import logging +from mock import Mock + from twisted.internet import defer from twisted.internet.defer import succeed @@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") @@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) # the result should be 5 redacted events, and 5 unredacted events. @@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # ... and the filtering happens. filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) for i in range(0, len(events_to_filter)): @@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield self.event_creation_handler.create_new_client_event( builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): logger.info("Starting filtering") start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + filtered = yield filter_events_for_server( test_store, "test_server", events_to_filter ) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index f907903511..39e360fe24 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.return_value = ["spam", "eggs"] r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = ["chips"] r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() diff --git a/tests/utils.py b/tests/utils.py index 8cced4b7e8..7dc9bdc505 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -325,10 +325,16 @@ def setup_test_homeserver( if homeserverToUse.__name__ == "TestHomeServer": hs.setup_master() else: + # If we have been given an explicit datastore we probably want to mock + # out the DataStores somehow too. This all feels a bit wrong, but then + # mocking the stores feels wrong too. + datastores = Mock(datastore=datastore) + hs = homeserverToUse( name, db_pool=None, datastore=datastore, + datastores=datastores, config=config, version_string="Synapse/tests", database_engine=db_engine, @@ -646,7 +652,7 @@ def create_room(hs, room_id, creator_id): creator_id (str) """ - store = hs.get_datastore() + persistence_store = hs.get_storage().persistence event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() @@ -663,4 +669,4 @@ def create_room(hs, room_id, creator_id): event, context = yield event_creation_handler.create_new_client_event(builder) - yield store.persist_event(event, context) + yield persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index e3a53f340a..afe9bc909b 100644 --- a/tox.ini +++ b/tox.ini @@ -114,16 +114,16 @@ skip_install = True basepython = python3.6 deps = flake8 - black==19.3b0 # We pin so that our tests don't start failing on new releases of black. + black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . - /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" + /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}" {toxinidir}/scripts-dev/config-lint.sh [testenv:check_isort] skip_install = True deps = isort -commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests" +commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts" [testenv:check-newsfragment] skip_install = True @@ -167,6 +167,6 @@ deps = env = MYPYPATH = stubs/ extras = all -commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \ +commands = mypy \ synapse/logging/ \ synapse/config/ |