diff --git a/UPGRADE.rst b/UPGRADE.rst
index 6492fa011f..fc8982ddfe 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,45 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.21.0
+====================
+
+Forwarding ``/_synapse/client`` through your reverse proxy
+----------------------------------------------------------
+
+The `reverse proxy documentation
+<https://github.com/matrix-org/synapse/blob/develop/docs/reverse_proxy.md>`_ has been updated
+to include reverse proxy directives for ``/_synapse/client/*`` endpoints. As the user password
+reset flow now uses endpoints under this prefix, **you must update your reverse proxy
+configurations for user password reset to work**.
+
+Additionally, note that the `Synapse worker documentation
+<https://github.com/matrix-org/synapse/blob/develop/docs/workers.md>`_ has been updated to
+ state that the ``/_synapse/client/password_reset/email/submit_token`` endpoint can be handled
+by all workers. If you make use of Synapse's worker feature, please update your reverse proxy
+configuration to reflect this change.
+
+New HTML templates
+------------------
+
+A new HTML template,
+`password_reset_confirmation.html <https://github.com/matrix-org/synapse/blob/develop/synapse/res/templates/password_reset_confirmation.html>`_,
+has been added to the ``synapse/res/templates`` directory. If you are using a
+custom template directory, you may want to copy the template over and modify it.
+
+Note that as of v1.20.0, templates do not need to be included in custom template
+directories for Synapse to start. The default templates will be used if a custom
+template cannot be found.
+
+This page will appear to the user after clicking a password reset link that has
+been emailed to them.
+
+To complete password reset, the page must include a way to make a `POST`
+request to
+``/_synapse/client/password_reset/{medium}/submit_token``
+with the query parameters from the original link, presented as a URL-encoded form. See the file
+itself for more details.
+
Upgrading to v1.18.0
====================
diff --git a/changelog.d/7124.bugfix b/changelog.d/7124.bugfix
new file mode 100644
index 0000000000..8fd177780d
--- /dev/null
+++ b/changelog.d/7124.bugfix
@@ -0,0 +1 @@
+Fix a bug in the media repository where remote thumbnails with the same size but different crop methods would overwrite each other. Contributed by @deepbluev7.
diff --git a/changelog.d/7796.bugfix b/changelog.d/7796.bugfix
new file mode 100644
index 0000000000..65e5eb42a2
--- /dev/null
+++ b/changelog.d/7796.bugfix
@@ -0,0 +1 @@
+Fix inconsistent handling of non-existent push rules, and stop tracking the `enabled` state of removed push rules.
diff --git a/changelog.d/8004.feature b/changelog.d/8004.feature
new file mode 100644
index 0000000000..a91b75e0e0
--- /dev/null
+++ b/changelog.d/8004.feature
@@ -0,0 +1 @@
+Require the user to confirm that their password should be reset after clicking the email confirmation link.
\ No newline at end of file
diff --git a/changelog.d/8208.misc b/changelog.d/8208.misc
new file mode 100644
index 0000000000..e65da88c46
--- /dev/null
+++ b/changelog.d/8208.misc
@@ -0,0 +1 @@
+Fix tests on distros which disable TLSv1.0. Contributed by @danc86.
diff --git a/changelog.d/8216.misc b/changelog.d/8216.misc
new file mode 100644
index 0000000000..b38911b0e5
--- /dev/null
+++ b/changelog.d/8216.misc
@@ -0,0 +1 @@
+Simplify the distributor code to avoid unnecessary work.
diff --git a/changelog.d/8227.doc b/changelog.d/8227.doc
new file mode 100644
index 0000000000..4a43015a83
--- /dev/null
+++ b/changelog.d/8227.doc
@@ -0,0 +1 @@
+Add `/_synapse/client` to the reverse proxy documentation.
diff --git a/changelog.d/8230.misc b/changelog.d/8230.misc
new file mode 100644
index 0000000000..bf0ba76730
--- /dev/null
+++ b/changelog.d/8230.misc
@@ -0,0 +1 @@
+Track the latest event for every destination and room for catch-up after federation outage.
diff --git a/changelog.d/8236.bugfix b/changelog.d/8236.bugfix
new file mode 100644
index 0000000000..6f04871015
--- /dev/null
+++ b/changelog.d/8236.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where files that could not be thumbnailed would result in an Internal Server Error.
diff --git a/changelog.d/8243.misc b/changelog.d/8243.misc
new file mode 100644
index 0000000000..f7375d32d3
--- /dev/null
+++ b/changelog.d/8243.misc
@@ -0,0 +1 @@
+Remove the 'populate_stats_process_rooms_2' background job and restore functionality to 'populate_stats_process_rooms'.
\ No newline at end of file
diff --git a/changelog.d/8247.misc b/changelog.d/8247.misc
new file mode 100644
index 0000000000..3c27803be4
--- /dev/null
+++ b/changelog.d/8247.misc
@@ -0,0 +1 @@
+Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage.
diff --git a/changelog.d/8250.misc b/changelog.d/8250.misc
new file mode 100644
index 0000000000..b6896a9300
--- /dev/null
+++ b/changelog.d/8250.misc
@@ -0,0 +1 @@
+Clean up type hints for `PaginationConfig`.
diff --git a/changelog.d/8256.misc b/changelog.d/8256.misc
new file mode 100644
index 0000000000..bf0ba76730
--- /dev/null
+++ b/changelog.d/8256.misc
@@ -0,0 +1 @@
+Track the latest event for every destination and room for catch-up after federation outage.
diff --git a/changelog.d/8257.misc b/changelog.d/8257.misc
new file mode 100644
index 0000000000..47ac583eb4
--- /dev/null
+++ b/changelog.d/8257.misc
@@ -0,0 +1 @@
+Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`.
diff --git a/changelog.d/8258.misc b/changelog.d/8258.misc
new file mode 100644
index 0000000000..3c27803be4
--- /dev/null
+++ b/changelog.d/8258.misc
@@ -0,0 +1 @@
+Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage.
diff --git a/changelog.d/8259.misc b/changelog.d/8259.misc
new file mode 100644
index 0000000000..a26779a664
--- /dev/null
+++ b/changelog.d/8259.misc
@@ -0,0 +1 @@
+Switch to the JSON implementation from the standard library.
diff --git a/changelog.d/8260.misc b/changelog.d/8260.misc
new file mode 100644
index 0000000000..164eea8b59
--- /dev/null
+++ b/changelog.d/8260.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.util.async_helpers`.
diff --git a/changelog.d/8261.misc b/changelog.d/8261.misc
new file mode 100644
index 0000000000..bc91e9375c
--- /dev/null
+++ b/changelog.d/8261.misc
@@ -0,0 +1 @@
+Simplify tests that mock asynchronous functions.
diff --git a/changelog.d/8262.bugfix b/changelog.d/8262.bugfix
new file mode 100644
index 0000000000..2b84927de3
--- /dev/null
+++ b/changelog.d/8262.bugfix
@@ -0,0 +1 @@
+Upgrade canonicaljson to version 1.4.0Â to fix an unicode encoding issue.
diff --git a/changelog.d/8265.bugfix b/changelog.d/8265.bugfix
new file mode 100644
index 0000000000..981a836d21
--- /dev/null
+++ b/changelog.d/8265.bugfix
@@ -0,0 +1 @@
+Fix logstanding bug which could lead to incomplete database upgrades on SQLite.
diff --git a/changelog.d/8268.bugfix b/changelog.d/8268.bugfix
new file mode 100644
index 0000000000..4b15a60253
--- /dev/null
+++ b/changelog.d/8268.bugfix
@@ -0,0 +1 @@
+Fix stack overflow when stderr is redirected to the logging system, and the logging system encounters an error.
diff --git a/changelog.d/8275.feature b/changelog.d/8275.feature
new file mode 100644
index 0000000000..17549c3df3
--- /dev/null
+++ b/changelog.d/8275.feature
@@ -0,0 +1 @@
+Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number.
\ No newline at end of file
diff --git a/changelog.d/8278.bugfix b/changelog.d/8278.bugfix
new file mode 100644
index 0000000000..50e40ca2a9
--- /dev/null
+++ b/changelog.d/8278.bugfix
@@ -0,0 +1 @@
+Fix a bug which cause the logging system to report errors, if `DEBUG` was enabled and no `context` filter was applied.
diff --git a/changelog.d/8279.misc b/changelog.d/8279.misc
new file mode 100644
index 0000000000..99f669001f
--- /dev/null
+++ b/changelog.d/8279.misc
@@ -0,0 +1 @@
+Add type hints to `StreamToken` and `RoomStreamToken` classes.
diff --git a/changelog.d/8281.misc b/changelog.d/8281.misc
new file mode 100644
index 0000000000..74357120a7
--- /dev/null
+++ b/changelog.d/8281.misc
@@ -0,0 +1 @@
+Change `StreamToken.room_key` to be a `RoomStreamToken` instance.
diff --git a/changelog.d/8282.misc b/changelog.d/8282.misc
new file mode 100644
index 0000000000..b6896a9300
--- /dev/null
+++ b/changelog.d/8282.misc
@@ -0,0 +1 @@
+Clean up type hints for `PaginationConfig`.
diff --git a/changelog.d/8285.misc b/changelog.d/8285.misc
new file mode 100644
index 0000000000..4646664ba1
--- /dev/null
+++ b/changelog.d/8285.misc
@@ -0,0 +1 @@
+Blacklist [MSC2753](https://github.com/matrix-org/matrix-doc/pull/2753) SyTests until it is implemented.
\ No newline at end of file
diff --git a/changelog.d/8287.bugfix b/changelog.d/8287.bugfix
new file mode 100644
index 0000000000..839781aa07
--- /dev/null
+++ b/changelog.d/8287.bugfix
@@ -0,0 +1 @@
+Fix edge case where push could get delayed for a user until a later event was pushed.
diff --git a/changelog.d/8288.misc b/changelog.d/8288.misc
new file mode 100644
index 0000000000..c08a53a5ee
--- /dev/null
+++ b/changelog.d/8288.misc
@@ -0,0 +1 @@
+Refactor notifier code to correctly use the max event stream position.
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index fd48ba0874..edd109fa7b 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -11,7 +11,7 @@ privileges.
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
the requested URI in any way (for example, by decoding `%xx` escapes).
-Beware that Apache *will* canonicalise URIs unless you specifify
+Beware that Apache *will* canonicalise URIs unless you specify
`nocanon`.
When setting up a reverse proxy, remember that Matrix clients and other
@@ -23,6 +23,10 @@ specification](https://matrix.org/docs/spec/server_server/latest#resolving-serve
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
+Endpoints that are part of the standardised Matrix specification are
+located under `/_matrix`, whereas endpoints specific to Synapse are
+located under `/_synapse/client`.
+
Let's assume that we expect clients to connect to our server at
`https://matrix.example.com`, and other servers to connect at
`https://example.com:8448`. The following sections detail the configuration of
@@ -45,7 +49,7 @@ server {
server_name matrix.example.com;
- location /_matrix {
+ location ~* ^(\/_matrix|\/_synapse\/client) {
proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr;
# Nginx by default only allows file uploads up to 1M in size
@@ -65,6 +69,10 @@ matrix.example.com {
proxy /_matrix http://localhost:8008 {
transparent
}
+
+ proxy /_synapse/client http://localhost:8008 {
+ transparent
+ }
}
example.com:8448 {
@@ -79,6 +87,7 @@ example.com:8448 {
```
matrix.example.com {
reverse_proxy /_matrix/* http://localhost:8008
+ reverse_proxy /_synapse/client/* http://localhost:8008
}
example.com:8448 {
@@ -96,6 +105,8 @@ example.com:8448 {
AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+ ProxyPass /_synapse/client http://127.0.0.1:8008/_synapse/client nocanon
+ ProxyPassReverse /_synapse/client http://127.0.0.1:8008/_synapse/client
</VirtualHost>
<VirtualHost *:8448>
@@ -119,6 +130,7 @@ frontend https
# Matrix client traffic
acl matrix-host hdr(host) -i matrix.example.com
acl matrix-path path_beg /_matrix
+ acl matrix-path path_beg /_synapse/client
use_backend matrix if matrix-host matrix-path
@@ -146,3 +158,10 @@ connecting to Synapse from a client.
Synapse exposes a health check endpoint for use by reverse proxies.
Each configured HTTP listener has a `/health` endpoint which always returns
200 OK (and doesn't get logged).
+
+## Synapse administration endpoints
+
+Endpoints for administering your Synapse instance are placed under
+`/_synapse/admin`. These require authentication through an access token of an
+admin user. However as access to these endpoints grants the caller a lot of power,
+we do not recommend exposing them to the public internet without good reason.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 3528d9e11f..2a5b2e0935 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -432,6 +432,24 @@ retention:
#
#request_token_inhibit_3pid_errors: true
+# A list of domains that the domain portion of 'next_link' parameters
+# must match.
+#
+# This parameter is optionally provided by clients while requesting
+# validation of an email or phone number, and maps to a link that
+# users will be automatically redirected to after validation
+# succeeds. Clients can make use this parameter to aid the validation
+# process.
+#
+# The whitelist is applied whether the homeserver or an
+# identity server is handling validation.
+#
+# The default value is no whitelist functionality; all domains are
+# allowed. Setting this value to an empty list will instead disallow
+# all domains.
+#
+#next_link_domain_whitelist: ["matrix.org"]
+
## TLS ##
@@ -2021,9 +2039,13 @@ email:
# * The contents of password reset emails sent by the homeserver:
# 'password_reset.html' and 'password_reset.txt'
#
- # * HTML pages for success and failure that a user will see when they follow
- # the link in the password reset email: 'password_reset_success.html' and
- # 'password_reset_failure.html'
+ # * An HTML page that a user will see when they follow the link in the password
+ # reset email. The user will be asked to confirm the action before their
+ # password is reset: 'password_reset_confirmation.html'
+ #
+ # * HTML pages for success and failure that a user will see when they confirm
+ # the password reset flow using the page above: 'password_reset_success.html'
+ # and 'password_reset_failure.html'
#
# * The contents of address verification emails sent during registration:
# 'registration.html' and 'registration.txt'
diff --git a/docs/workers.md b/docs/workers.md
index bfec745897..df0ac84d94 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -217,6 +217,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
+ ^/_synapse/client/password_reset/email/submit_token$
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$
diff --git a/mypy.ini b/mypy.ini
index 7764f17856..7986781432 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -34,7 +34,7 @@ files =
synapse/http/federation/well_known_resolver.py,
synapse/http/server.py,
synapse/http/site.py,
- synapse/logging/,
+ synapse/logging,
synapse/metrics,
synapse/module_api,
synapse/notifier.py,
@@ -46,14 +46,17 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
+ synapse/storage/databases/main/events.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
+ synapse/storage/persist_events.py,
synapse/storage/state.py,
synapse/storage/util,
synapse/streams,
synapse/types.py,
+ synapse/util/async_helpers.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 2a2c9e6f13..bb33345be6 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,10 +15,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
from typing import List
import jsonschema
-from canonicaljson import json
from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index bbfccf955e..6379c86dde 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -21,6 +21,7 @@ from urllib.parse import urlencode
from synapse.config import ConfigError
+SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
CLIENT_API_PREFIX = "/_matrix/client"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index b6c9085670..7d309b1bb0 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import json
import logging
import os
import sys
import tempfile
-from canonicaljson import json
-
from twisted.internet import defer, task
import synapse
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6014adc850..b08319ca77 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -48,6 +48,7 @@ from synapse.api.urls import (
from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
@@ -209,6 +210,15 @@ class SynapseHomeServer(HomeServer):
resources["/_matrix/saml2"] = SAML2Resource(self)
+ if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ from synapse.rest.synapse.client.password_reset import (
+ PasswordResetSubmitTokenResource,
+ )
+
+ resources[
+ "/_synapse/client/password_reset/email/submit_token"
+ ] = PasswordResetSubmitTokenResource(self)
+
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 7a796996c0..72b42bfd62 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -228,6 +228,7 @@ class EmailConfig(Config):
self.email_registration_template_text,
self.email_add_threepid_template_html,
self.email_add_threepid_template_text,
+ self.email_password_reset_template_confirmation_html,
self.email_password_reset_template_failure_html,
self.email_registration_template_failure_html,
self.email_add_threepid_template_failure_html,
@@ -242,6 +243,7 @@ class EmailConfig(Config):
registration_template_text,
add_threepid_template_html,
add_threepid_template_text,
+ "password_reset_confirmation.html",
password_reset_template_failure_html,
registration_template_failure_html,
add_threepid_template_failure_html,
@@ -404,9 +406,13 @@ class EmailConfig(Config):
# * The contents of password reset emails sent by the homeserver:
# 'password_reset.html' and 'password_reset.txt'
#
- # * HTML pages for success and failure that a user will see when they follow
- # the link in the password reset email: 'password_reset_success.html' and
- # 'password_reset_failure.html'
+ # * An HTML page that a user will see when they follow the link in the password
+ # reset email. The user will be asked to confirm the action before their
+ # password is reset: 'password_reset_confirmation.html'
+ #
+ # * HTML pages for success and failure that a user will see when they confirm
+ # the password reset flow using the page above: 'password_reset_success.html'
+ # and 'password_reset_failure.html'
#
# * The contents of address verification emails sent during registration:
# 'registration.html' and 'registration.txt'
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index c96e6ef62a..13d6f6a3ea 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -17,6 +17,7 @@ import logging
import logging.config
import os
import sys
+import threading
from string import Template
import yaml
@@ -25,6 +26,7 @@ from twisted.logger import (
ILogObserver,
LogBeginner,
STDLibLogObserver,
+ eventAsText,
globalLogBeginner,
)
@@ -216,8 +218,9 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
# system.
observer = STDLibLogObserver()
- def _log(event):
+ threadlocal = threading.local()
+ def _log(event):
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
@@ -228,7 +231,25 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
if event["log_text"].startswith("Timing out client"):
return
- return observer(event)
+ # this is a workaround to make sure we don't get stack overflows when the
+ # logging system raises an error which is written to stderr which is redirected
+ # to the logging system, etc.
+ if getattr(threadlocal, "active", False):
+ # write the text of the event, if any, to the *real* stderr (which may
+ # be redirected to /dev/null, but there's not much we can do)
+ try:
+ event_text = eventAsText(event)
+ print("logging during logging: %s" % event_text, file=sys.__stderr__)
+ except Exception:
+ # gah.
+ pass
+ return
+
+ try:
+ threadlocal.active = True
+ return observer(event)
+ finally:
+ threadlocal.active = False
logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
if not config.no_redirect_stdio:
diff --git a/synapse/config/server.py b/synapse/config/server.py
index e85c6a0840..532b910470 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Set
import attr
import yaml
@@ -542,6 +542,19 @@ class ServerConfig(Config):
users_new_default_push_rules
) # type: set
+ # Whitelist of domain names that given next_link parameters must have
+ next_link_domain_whitelist = config.get(
+ "next_link_domain_whitelist"
+ ) # type: Optional[List[str]]
+
+ self.next_link_domain_whitelist = None # type: Optional[Set[str]]
+ if next_link_domain_whitelist is not None:
+ if not isinstance(next_link_domain_whitelist, list):
+ raise ConfigError("'next_link_domain_whitelist' must be a list")
+
+ # Turn the list into a set to improve lookup speed.
+ self.next_link_domain_whitelist = set(next_link_domain_whitelist)
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
@@ -1014,6 +1027,24 @@ class ServerConfig(Config):
# act as if no error happened and return a fake session ID ('sid') to clients.
#
#request_token_inhibit_3pid_errors: true
+
+ # A list of domains that the domain portion of 'next_link' parameters
+ # must match.
+ #
+ # This parameter is optionally provided by clients while requesting
+ # validation of an email or phone number, and maps to a link that
+ # users will be automatically redirected to after validation
+ # succeeds. Clients can make use this parameter to aid the validation
+ # process.
+ #
+ # The whitelist is applied whether the homeserver or an
+ # identity server is handling validation.
+ #
+ # The default value is no whitelist functionality; all domains are
+ # allowed. Setting this value to an empty list will instead disallow
+ # all domains.
+ #
+ #next_link_domain_whitelist: ["matrix.org"]
"""
% locals()
)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 552519e82c..41a726878d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -209,7 +209,7 @@ class FederationSender:
logger.debug("Sending %s to %r", event, destinations)
if destinations:
- self._send_pdu(event, destinations)
+ await self._send_pdu(event, destinations)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@@ -265,7 +265,7 @@ class FederationSender:
finally:
self._is_processing = False
- def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
@@ -280,6 +280,13 @@ class FederationSender:
sent_pdus_destination_dist_total.inc(len(destinations))
sent_pdus_destination_dist_count.inc()
+ # track the fact that we have a PDU for these destinations,
+ # to allow us to perform catch-up later on if the remote is unreachable
+ # for a while.
+ await self.store.store_destination_rooms_entries(
+ destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+ )
+
for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index defc228c23..9f0852b4a2 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -325,6 +325,17 @@ class PerDestinationQueue:
self._last_device_stream_id = device_stream_id
self._last_device_list_stream_id = dev_list_id
+
+ if pending_pdus:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ final_pdu = pending_pdus[-1]
+ last_successful_stream_ordering = (
+ final_pdu.internal_metadata.stream_ordering
+ )
+ await self._store.set_destination_last_successful_stream_ordering(
+ self._destination, last_successful_stream_ordering
+ )
else:
break
except NotRetryingDestination as e:
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 918d0e037c..5e5a64037d 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
else:
stream_ordering = room.stream_ordering
- from_key = str(RoomStreamToken(0, 0))
- to_key = str(RoomStreamToken(None, stream_ordering))
+ from_key = RoomStreamToken(0, 0)
+ to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
if not events:
break
- from_key = events[-1].internal_metadata.after
+ from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
events = await filter_events_for_client(self.storage, user_id, events)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 643d71a710..4b0a4f96cc 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
RoomStreamToken,
+ StreamToken,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
@trace
@measure_func("device.get_user_ids_changed")
- async def get_user_ids_changed(self, user_id, from_token):
+ async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
-
- Args:
- user_id (str)
- from_token (StreamToken)
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
- now_room_key = await self.store.get_room_events_max_id()
+ now_room_id = self.store.get_room_max_stream_ordering()
+ now_room_key = RoomStreamToken(None, now_room_id)
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
)
rooms_changed.update(event.room_id for event in member_events)
- stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+ stream_ordering = from_token.room_key.stream
possibly_changed = set(changed)
possibly_left = set()
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b05e32f457..fdce54c5c3 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -39,10 +39,6 @@ class EventStreamHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super(EventStreamHandler, self).__init__(hs)
- self.distributor = hs.get_distributor()
- self.distributor.declare("started_user_eventstream")
- self.distributor.declare("stopped_user_eventstream")
-
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 43f2986f89..c195eba830 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -69,7 +69,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet,
)
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
@@ -80,7 +79,6 @@ from synapse.types import (
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
@@ -130,7 +128,6 @@ class FederationHandler(BaseHandler):
self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
- self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._message_handler = hs.get_message_handler()
@@ -141,9 +138,6 @@ class FederationHandler(BaseHandler):
self._replication = hs.get_replication_data_handler()
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
- self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
- hs
- )
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
)
@@ -704,31 +698,10 @@ class FederationHandler(BaseHandler):
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
- context = await self._handle_new_event(origin, event, state=state)
+ await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally
- # joined the room. Don't bother if the user is just
- # changing their profile info.
- newly_joined = True
-
- prev_state_ids = await context.get_prev_state_ids()
-
- prev_state_id = prev_state_ids.get((event.type, event.state_key))
- if prev_state_id:
- prev_state = await self.store.get_event(
- prev_state_id, allow_none=True
- )
- if prev_state and prev_state.membership == Membership.JOIN:
- newly_joined = False
-
- if newly_joined:
- user = UserID.from_string(event.state_key)
- await self.user_joined_room(user, room_id)
-
# For encrypted messages we check that we know about the sending device,
# if we don't then we mark the device cache for that user as stale.
if event.type == EventTypes.Encrypted:
@@ -1550,11 +1523,6 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- if event.type == EventTypes.Member:
- if event.content["membership"] == Membership.JOIN:
- user = UserID.from_string(event.state_key)
- await self.user_joined_room(user, event.room_id)
-
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
@@ -2970,8 +2938,6 @@ class FederationHandler(BaseHandler):
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
@@ -2984,16 +2950,6 @@ class FederationHandler(BaseHandler):
else:
await self.store.clean_room_for_join(room_id)
- async def user_joined_room(self, user: UserID, room_id: str) -> None:
- """Called when a new user has joined the room
- """
- if self.config.worker_app:
- await self._notify_user_membership_change(
- room_id=room_id, user_id=user.to_string(), change="joined"
- )
- else:
- user_joined_room(self.distributor, user, room_id)
-
async def get_room_complexity(
self, remote_room_hosts: List[str], room_id: str
) -> Optional[dict]:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d5ddc583ad..ba4828c713 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
- pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = await presence_stream.get_pagination_rows(
- user, pagination_config.get_source_config("presence"), None
+ presence, _ = await presence_stream.get_new_events(
+ user, from_key=None, include_offline=False
)
- receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = await receipt_stream.get_pagination_rows(
- user, pagination_config.get_source_config("receipt"), None
+ joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
+ receipt = await self.store.get_linearized_receipts_for_rooms(
+ joined_rooms, to_key=int(now_token.receipt_key),
)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -168,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
- room_end_token = "s%d" % (event.stream_ordering,)
+ room_end_token = RoomStreamToken(None, event.stream_ordering,)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 8a7b4916cd..e54e2b322b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -387,8 +387,6 @@ class EventCreationHandler:
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
- self.pusher_pool = hs.get_pusherpool()
-
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
@@ -975,6 +973,7 @@ class EventCreationHandler:
This should only be run on the instance in charge of persisting events.
"""
assert self._is_event_writer
+ assert self.storage.persistence is not None
if ratelimit:
# We check if this is a room admin redacting an event so that we
@@ -1145,8 +1144,6 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
def _notify():
try:
self.notifier.on_new_room_event(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 34ed0e2921..d929a68f7d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -335,20 +335,16 @@ class PaginationHandler:
user_id = requester.user.to_string()
if pagin_config.from_token:
- room_token = pagin_config.from_token.room_key
+ from_token = pagin_config.from_token
else:
- pagin_config.from_token = (
- self.hs.get_event_sources().get_current_token_for_pagination()
- )
- room_token = pagin_config.from_token.room_key
-
- room_token = RoomStreamToken.parse(room_token)
+ from_token = self.hs.get_event_sources().get_current_token_for_pagination()
- pagin_config.from_token = pagin_config.from_token.copy_and_replace(
- "room_key", str(room_token)
- )
+ if pagin_config.limit is None:
+ # This shouldn't happen as we've set a default limit before this
+ # gets called.
+ raise Exception("limit not set")
- source_config = pagin_config.get_source_config("room")
+ room_token = from_token.room_key
with await self.pagination_lock.read(room_id):
(
@@ -358,7 +354,7 @@ class PaginationHandler:
room_id, user_id, allow_departed_users=True
)
- if source_config.direction == "b":
+ if pagin_config.direction == "b":
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
@@ -377,26 +373,35 @@ class PaginationHandler:
# case "JOIN" would have been returned.
assert member_event_id
- leave_token = await self.store.get_topological_token_for_event(
+ leave_token_str = await self.store.get_topological_token_for_event(
member_event_id
)
- if RoomStreamToken.parse(leave_token).topological < max_topo:
- source_config.from_key = str(leave_token)
+ leave_token = RoomStreamToken.parse(leave_token_str)
+ assert leave_token.topological is not None
+
+ if leave_token.topological < max_topo:
+ from_token = from_token.copy_and_replace(
+ "room_key", leave_token
+ )
await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
+ to_room_key = None
+ if pagin_config.to_token:
+ to_room_key = pagin_config.to_token.room_key
+
events, next_key = await self.store.paginate_room_events(
room_id=room_id,
- from_key=source_config.from_key,
- to_key=source_config.to_key,
- direction=source_config.direction,
- limit=source_config.limit,
+ from_key=from_token.room_key,
+ to_key=to_room_key,
+ direction=pagin_config.direction,
+ limit=pagin_config.limit,
event_filter=event_filter,
)
- next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
+ next_token = from_token.copy_and_replace("room_key", next_key)
if events:
if event_filter:
@@ -409,7 +414,7 @@ class PaginationHandler:
if not events:
return {
"chunk": [],
- "start": pagin_config.from_token.to_string(),
+ "start": from_token.to_string(),
"end": next_token.to_string(),
}
@@ -438,7 +443,7 @@ class PaginationHandler:
events, time_now, as_client_event=as_client_event
)
),
- "start": pagin_config.from_token.to_string(),
+ "start": from_token.to_string(),
"end": next_token.to_string(),
}
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 91a3aec1cc..1000ac95ff 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1108,9 +1108,6 @@ class PresenceEventSource:
def get_current_key(self):
return self.store.get_current_presence_token()
- async def get_pagination_rows(self, user, pagination_config, key):
- return await self.get_new_events(user, from_key=None, include_offline=False)
-
@cached(num_args=2, cache_context=True)
async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 2cc6c2eb68..bdd8e52edd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -142,18 +142,3 @@ class ReceiptEventSource:
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
-
- async def get_pagination_rows(self, user, config, key):
- to_key = int(config.from_key)
-
- if config.to_key:
- from_key = int(config.to_key)
- else:
- from_key = None
-
- room_ids = await self.store.get_rooms_for_user(user.to_string())
- events = await self.store.get_linearized_receipts_for_rooms(
- room_ids, from_key=from_key, to_key=to_key
- )
-
- return (events, to_key)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a29305f655..53d85ab97d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1091,20 +1091,19 @@ class RoomEventSource:
async def get_new_events(
self,
user: UserID,
- from_key: str,
+ from_key: RoomStreamToken,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
# We just ignore the key for now.
to_key = self.get_current_key()
- from_token = RoomStreamToken.parse(from_key)
- if from_token.topological:
+ if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
- from_key = "s%s" % (from_token.stream,)
+ from_key = RoomStreamToken(None, from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
@@ -1133,14 +1132,14 @@ class RoomEventSource:
events[:] = events[:limit]
if events:
- end_key = events[-1].internal_metadata.after
+ end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
else:
end_key = to_key
return (events, end_key)
- def get_current_key(self) -> str:
- return "s%d" % (self.store.get_room_max_stream_ordering(),)
+ def get_current_key(self) -> RoomStreamToken:
+ return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 32b7e323fa..100f335b80 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
from ._base import BaseHandler
@@ -149,17 +149,6 @@ class RoomMemberHandler:
raise NotImplementedError()
@abc.abstractmethod
- async def _user_joined_room(self, target: UserID, room_id: str) -> None:
- """Notifies distributor on master process that the user has joined the
- room.
-
- Args:
- target
- room_id
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
@@ -221,7 +210,6 @@ class RoomMemberHandler:
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- newly_joined = False
if event.membership == Membership.JOIN:
newly_joined = True
if prev_member_event_id:
@@ -246,12 +234,7 @@ class RoomMemberHandler:
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
- if event.membership == Membership.JOIN and newly_joined:
- # Only fire user_joined_room if the user has actually joined the
- # room. Don't bother if the user is just changing their profile
- # info.
- await self._user_joined_room(target, room_id)
- elif event.membership == Membership.LEAVE:
+ if event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
@@ -726,17 +709,7 @@ class RoomMemberHandler:
(EventTypes.Member, event.state_key), None
)
- if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has actually joined the
- # room. Don't bother if the user is just changing their profile
- # info.
- newly_joined = True
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- newly_joined = prev_member_event.membership != Membership.JOIN
- if newly_joined:
- await self._user_joined_room(target_user, room_id)
- elif event.membership == Membership.LEAVE:
+ if event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
@@ -1002,10 +975,9 @@ class RoomMemberHandler:
class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs):
- super(RoomMemberMasterHandler, self).__init__(hs)
+ super().__init__(hs)
self.distributor = hs.get_distributor()
- self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
async def _is_remote_room_too_complex(
@@ -1085,7 +1057,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
- await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@@ -1228,11 +1199,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
return event.event_id, stream_id
- async def _user_joined_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_joined_room
- """
- user_joined_room(self.distributor, target, room_id)
-
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 897338fd54..e7f34737c6 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
content=content,
)
- await self._user_joined_room(user, room_id)
-
return ret["event_id"], ret["stream_id"]
async def remote_reject_invite(
@@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
- async def _user_joined_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_joined_room
- """
- await self._notify_change_client(
- user_id=target.to_string(), room_id=room_id, change="joined"
- )
-
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e2ddb628ff..a615c7c2f0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -378,7 +378,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"):
- typing_key = since_token.typing_key if since_token else "0"
+ typing_key = since_token.typing_key if since_token else 0
room_ids = sync_result_builder.joined_room_ids
@@ -402,7 +402,7 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
- receipt_key = since_token.receipt_key if since_token else "0"
+ receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = await receipt_source.get_new_events(
@@ -533,7 +533,7 @@ class SyncHandler:
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
- room_key = recents[0].internal_metadata.before
+ room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@@ -1310,12 +1310,11 @@ class SyncHandler:
presence_source = self.event_sources.sources["presence"]
since_token = sync_result_builder.since_token
+ presence_key = None
+ include_offline = False
if since_token and not sync_result_builder.full_state:
presence_key = since_token.presence_key
include_offline = True
- else:
- presence_key = None
- include_offline = False
presence, presence_key = await presence_source.get_new_events(
user=user,
@@ -1323,6 +1322,7 @@ class SyncHandler:
is_guest=sync_config.is_guest,
include_offline=include_offline,
)
+ assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace(
"presence_key", presence_key
)
@@ -1485,7 +1485,7 @@ class SyncHandler:
if rooms_changed:
return True
- stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+ stream_id = since_token.room_key.stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
return True
@@ -1751,7 +1751,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
- "room_key", "s%d" % (event.stream_ordering,)
+ "room_key", RoomStreamToken(None, event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index fea774e2e5..becf66dd86 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -29,11 +29,11 @@ def _log_debug_as_f(f, msg, msg_args):
lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
- record = logging.LogRecord(
+ record = logger.makeRecord(
name=name,
level=logging.DEBUG,
- pathname=pathname,
- lineno=lineno,
+ fn=pathname,
+ lno=lineno,
msg=msg,
args=msg_args,
exc_info=None,
diff --git a/synapse/notifier.py b/synapse/notifier.py
index b7f4041306..12cd84b27b 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, StreamToken, UserID
+from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -112,7 +112,9 @@ class _NotifierUserStream:
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
- def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
+ def notify(
+ self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+ ):
"""Notify any listeners for this user of a new event from an
event source.
Args:
@@ -187,7 +189,7 @@ class Notifier:
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
- ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
+ ) # type: List[Tuple[int, EventBase, Collection[UserID]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -198,6 +200,7 @@ class Notifier:
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
+ self._pusher_pool = hs.get_pusherpool()
self.federation_sender = None
if hs.should_send_federation():
@@ -247,7 +250,7 @@ class Notifier:
event: EventBase,
room_stream_id: int,
max_room_stream_id: int,
- extra_users: Collection[Union[str, UserID]] = [],
+ extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -274,47 +277,68 @@ class Notifier:
"""
pending = self.pending_new_room_events
self.pending_new_room_events = []
+
+ users = set() # type: Set[UserID]
+ rooms = set() # type: Set[str]
+
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
self.pending_new_room_events.append(
(room_stream_id, event, extra_users)
)
else:
- self._on_new_room_event(event, room_stream_id, extra_users)
+ if (
+ event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ ):
+ self._user_joined_room(event.state_key, event.room_id)
+
+ users.update(extra_users)
+ rooms.add(event.room_id)
+
+ if users or rooms:
+ self.on_new_event(
+ "room_key",
+ RoomStreamToken(None, max_room_stream_id),
+ users=users,
+ rooms=rooms,
+ )
+ self._on_updated_room_token(max_room_stream_id)
+
+ def _on_updated_room_token(self, max_room_stream_id: int):
+ """Poke services that might care that the room position has been
+ updated.
+ """
- def _on_new_room_event(
- self,
- event: EventBase,
- room_stream_id: int,
- extra_users: Collection[Union[str, UserID]] = [],
- ):
- """Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
- "notify_app_services", self._notify_app_services, room_stream_id
+ "_notify_app_services", self._notify_app_services, max_room_stream_id
)
- if self.federation_sender:
- self.federation_sender.notify_new_events(room_stream_id)
-
- if event.type == EventTypes.Member and event.membership == Membership.JOIN:
- self._user_joined_room(event.state_key, event.room_id)
-
- self.on_new_event(
- "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
+ run_as_background_process(
+ "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
)
- async def _notify_app_services(self, room_stream_id: int):
+ if self.federation_sender:
+ self.federation_sender.notify_new_events(max_room_stream_id)
+
+ async def _notify_app_services(self, max_room_stream_id: int):
try:
- await self.appservice_handler.notify_interested_services(room_stream_id)
+ await self.appservice_handler.notify_interested_services(max_room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
+ async def _notify_pusher_pool(self, max_room_stream_id: int):
+ try:
+ await self._pusher_pool.on_new_notifications(max_room_stream_id)
+ except Exception:
+ logger.exception("Error pusher pool of event")
+
def on_new_event(
self,
stream_key: str,
- new_token: int,
- users: Collection[Union[str, UserID]] = [],
+ new_token: Union[int, RoomStreamToken],
+ users: Collection[UserID] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise.
@@ -432,8 +456,9 @@ class Notifier:
If explicit_room_id is set, that room will be polled for events only if
it is world readable or the user has joined the room.
"""
- from_token = pagination_config.from_token
- if not from_token:
+ if pagination_config.from_token:
+ from_token = pagination_config.from_token
+ else:
from_token = self.event_sources.get_current_token()
limit = pagination_config.limit
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index b7ea4438e0..28bd8ab748 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -91,7 +91,7 @@ class EmailPusher:
pass
self.timed_call = None
- def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+ def on_new_notifications(self, max_stream_ordering):
if self.max_stream_ordering:
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f21fa9b659..26706bf3e1 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -114,7 +114,7 @@ class HttpPusher:
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+ def on_new_notifications(self, max_stream_ordering):
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering or 0
)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 6c57854018..455a1acb46 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -123,7 +123,7 @@ class Mailer:
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
self.hs.config.public_baseurl
- + "_matrix/client/unstable/password_reset/email/submit_token?%s"
+ + "_synapse/client/password_reset/email/submit_token?%s"
% urllib.parse.urlencode(params)
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3c3262a88c..cc839ffce4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -64,6 +64,12 @@ class PusherPool:
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
+ # Record the last stream ID that we were poked about so we can get
+ # changes since then. We set this to the current max stream ID on
+ # startup as every individual pusher will have checked for changes on
+ # startup.
+ self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
+
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
@@ -178,20 +184,27 @@ class PusherPool:
)
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
- async def on_new_notifications(self, min_stream_id, max_stream_id):
+ async def on_new_notifications(self, max_stream_id: int):
if not self.pushers:
# nothing to do here.
return
+ if max_stream_id < self._last_room_stream_id_seen:
+ # Nothing to do
+ return
+
+ prev_stream_id = self._last_room_stream_id_seen
+ self._last_room_stream_id_seen = max_stream_id
+
try:
users_affected = await self.store.get_push_action_users_in_range(
- min_stream_id, max_stream_id
+ prev_stream_id, max_stream_id
)
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
- p.on_new_notifications(min_stream_id, max_stream_id)
+ p.on_new_notifications(max_stream_id)
except Exception:
logger.exception("Exception in pusher on_new_notifications")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2d995ec456..ff0c67228b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
- "canonicaljson>=1.3.0",
+ "canonicaljson>=1.4.0",
# we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0",
"pynacl>=1.2.1",
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 741329ab5f..08095fdf7d 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
Args:
room_id (str)
user_id (str)
- change (str): Either "joined" or "left"
+ change (str): "left"
"""
- assert change in ("joined", "left")
+ assert change == "left"
return {}
@@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
user = UserID.from_string(user_id)
- if change == "joined":
- user_joined_room(self.distributor, user, room_id)
- elif change == "left":
+ if change == "left":
user_left_room(self.distributor, user, room_id)
else:
raise Exception("Unrecognized change: %r", change)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d6ecf5b327..e82b9e386f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
+from synapse.types import UserID
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -98,7 +99,6 @@ class ReplicationDataHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
@@ -148,14 +148,12 @@ class ReplicationDataHandler:
if event.rejected_reason:
continue
- extra_users = () # type: Tuple[str, ...]
+ extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
+ extra_users = (UserID.from_string(event.state_key),)
max_token = self.store.get_room_max_stream_ordering()
self.notifier.on_new_room_event(event, token, max_token, extra_users)
- await self.pusher_pool.on_new_notifications(token, token)
-
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html
new file mode 100644
index 0000000000..def4b5162b
--- /dev/null
+++ b/synapse/res/templates/password_reset_confirmation.html
@@ -0,0 +1,16 @@
+<html>
+<head></head>
+<body>
+<!--Use a hidden form to resubmit the information necessary to reset the password-->
+<form method="post">
+ <input type="hidden" name="sid" value="{{ sid }}">
+ <input type="hidden" name="token" value="{{ token }}">
+ <input type="hidden" name="client_secret" value="{{ client_secret }}">
+
+ <p>You have requested to <strong>reset your Matrix account password</strong>. Click the link below to confirm this action. <br /><br />
+ If you did not mean to do this, please close this page and your password will not be changed.</p>
+ <p><button type="submit">Confirm changing my password</button></p>
+</form>
+</body>
+</html>
+
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 87f927890c..40f5c32db2 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.rest.admin
from synapse.http.server import JsonResource
+from synapse.rest import admin
from synapse.rest.client import versions
from synapse.rest.client.v1 import (
directory,
@@ -123,9 +123,7 @@ class ClientRestResource(JsonResource):
password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
- synapse.rest.admin.register_servlets_for_client_rest_resource(
- hs, client_resource
- )
+ admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e781a3bcf4..ddf8ed5e9c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet):
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
+ if spec["attr"] not in ("enabled", "actions"):
+ # for the sake of potential future expansion, shouldn't report
+ # 404 in the case of an unknown request so check it corresponds to
+ # a known attribute first.
+ raise UnrecognizedRequestError()
+
+ namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+ rule_id = spec["rule_id"]
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return await self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val
+ user_id, namespaced_rule_id, val, is_default_rule
)
elif spec["attr"] == "actions":
actions = val.get("actions")
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3481477731..c6cb9deb2b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -17,6 +17,11 @@
import logging
import random
from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -98,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -144,81 +152,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret
-class PasswordResetSubmitTokenServlet(RestServlet):
- """Handles 3PID validation token submission"""
-
- PATTERNS = client_patterns(
- "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
- )
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- super(PasswordResetSubmitTokenServlet, self).__init__()
- self.hs = hs
- self.auth = hs.get_auth()
- self.config = hs.config
- self.clock = hs.get_clock()
- self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self._failure_email_template = (
- self.config.email_password_reset_template_failure_html
- )
-
- async def on_GET(self, request, medium):
- # We currently only handle threepid token submissions for email
- if medium != "email":
- raise SynapseError(
- 400, "This medium is currently not supported for password resets"
- )
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Password reset emails have been disabled due to lack of an email config"
- )
- raise SynapseError(
- 400, "Email-based password resets are disabled on this server"
- )
-
- sid = parse_string(request, "sid", required=True)
- token = parse_string(request, "token", required=True)
- client_secret = parse_string(request, "client_secret", required=True)
- assert_valid_client_secret(client_secret)
-
- # Attempt to validate a 3PID session
- try:
- # Mark the session as valid
- next_link = await self.store.validate_threepid_session(
- sid, client_secret, token, self.clock.time_msec()
- )
-
- # Perform a 302 redirect if next_link is set
- if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
-
- # Otherwise show the success template
- html = self.config.email_password_reset_template_success_html_content
- status_code = 200
- except ThreepidValidationError as e:
- status_code = e.code
-
- # Show a failure page with a reason
- template_vars = {"failure_reason": e.msg}
- html = self._failure_email_template.render(**template_vars)
-
- respond_with_html(request, status_code, html)
-
-
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
@@ -446,6 +379,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
@@ -517,6 +453,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -603,15 +542,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
# Otherwise show the success template
html = self.config.email_add_threepid_template_success_html_content
@@ -875,6 +809,45 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+ """
+ Raises a SynapseError if a given next_link value is invalid
+
+ next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
+ option is either empty or contains a domain that matches the one in the given next_link
+
+ Args:
+ hs: The homeserver object
+ next_link: The next_link value given by the client
+
+ Raises:
+ SynapseError: If the next_link is invalid
+ """
+ valid = True
+
+ # Parse the contents of the URL
+ next_link_parsed = urlparse(next_link)
+
+ # Scheme must not point to the local drive
+ if next_link_parsed.scheme == "file":
+ valid = False
+
+ # If the domain whitelist is set, the domain must be in it
+ if (
+ valid
+ and hs.config.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ ):
+ valid = False
+
+ if not valid:
+ raise SynapseError(
+ 400,
+ "'next_link' domain not included in whitelist, or not http(s)",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
@@ -890,7 +863,6 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d2826374a7..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -80,7 +80,7 @@ class MediaFilePaths:
self, server_name, file_id, width, height, content_type, method
):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths:
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+ # Legacy path that was used to store thumbnails previously.
+ # Should be removed after some time, when most of the thumbnails are stored
+ # using the new path.
+ def remote_media_thumbnail_rel_legacy(
+ self, server_name, file_id, width, height, content_type
+ ):
+ top_level_type, sub_type = content_type.split("/")
+ file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ return os.path.join(
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
+ file_name,
+ )
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9a1b7779f7..69f353d46f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ from .media_storage import MediaStorage
from .preview_url_resource import PreviewUrlResource
from .storage_provider import StorageProviderWrapper
from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
logger = logging.getLogger(__name__)
@@ -460,13 +460,30 @@ class MediaRepository:
return t_byte_source
async def generate_local_exact_thumbnail(
- self, media_id, t_width, t_height, t_method, t_type, url_cache
- ):
+ self,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ url_cache: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+ media_id,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -506,14 +523,36 @@ class MediaRepository:
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def generate_remote_exact_thumbnail(
- self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+ media_id,
+ server_name,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -559,6 +598,9 @@ class MediaRepository:
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def _generate_thumbnails(
self,
server_name: Optional[str],
@@ -590,7 +632,18 @@ class MediaRepository:
FileInfo(server_name, file_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s",
+ media_id,
+ server_name,
+ media_type,
+ e,
+ )
+ return None
+
m_width = thumbnailer.width
m_height = thumbnailer.height
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3a352b5631..5681677fc9 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -147,6 +147,20 @@ class MediaStorage:
if os.path.exists(local_path):
return FileResponder(open(local_path, "rb"))
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return FileResponder(open(legacy_local_path, "rb"))
+
for provider in self.storage_providers:
res = await provider.fetch(path, file_info) # type: Any
if res:
@@ -170,6 +184,20 @@ class MediaStorage:
if os.path.exists(local_path):
return local_path
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return legacy_local_path
+
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _select_or_generate_remote_thumbnail(
self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index d681bf7bf0..457ad6031c 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
import logging
from io import BytesIO
-from PIL import Image as Image
+from PIL import Image
logger = logging.getLogger(__name__)
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
}
+class ThumbnailError(Exception):
+ """An error occurred generating a thumbnail."""
+
+
class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
- self.image = Image.open(input_path)
+ try:
+ self.image = Image.open(input_path)
+ except OSError as e:
+ # If an error occurs opening the image, a thumbnail won't be able to
+ # be generated.
+ raise ThumbnailError from e
+
self.width, self.height = self.image.size
self.transpose_method = None
try:
diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/client/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
new file mode 100644
index 0000000000..9e4fbc0cbd
--- /dev/null
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.errors import ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.server import DirectServeHtmlResource
+from synapse.http.servlet import parse_string
+from synapse.util.stringutils import assert_valid_client_secret
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
+ """Handles 3PID validation token submission
+
+ This resource gets mounted under /_synapse/client/password_reset/email/submit_token
+ """
+
+ isLeaf = 1
+
+ def __init__(self, hs: "HomeServer"):
+ """
+ Args:
+ hs: server
+ """
+ super().__init__()
+
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ self._local_threepid_handling_disabled_due_to_email_config = (
+ hs.config.local_threepid_handling_disabled_due_to_email_config
+ )
+ self._confirmation_email_template = (
+ hs.config.email_password_reset_template_confirmation_html
+ )
+ self._email_password_reset_template_success_html = (
+ hs.config.email_password_reset_template_success_html_content
+ )
+ self._failure_email_template = (
+ hs.config.email_password_reset_template_failure_html
+ )
+
+ # This resource should not be mounted if threepid behaviour is not LOCAL
+ assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+
+ async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
+
+ # Show a confirmation page, just in case someone accidentally clicked this link when
+ # they didn't mean to
+ template_vars = {
+ "sid": sid,
+ "token": token,
+ "client_secret": client_secret,
+ }
+ return (
+ 200,
+ self._confirmation_email_template.render(**template_vars).encode("utf-8"),
+ )
+
+ async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]:
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warning(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ next_link_bytes = next_link.encode("utf-8")
+ request.setHeader("Location", next_link_bytes)
+ return (
+ 302,
+ (
+ b'You are being redirected to <a src="%s">%s</a>.'
+ % (next_link_bytes, next_link_bytes)
+ ),
+ )
+
+ # Otherwise show the success template
+ html_bytes = self._email_password_reset_template_success_html.encode(
+ "utf-8"
+ )
+ status_code = 200
+ except ThreepidValidationError as e:
+ status_code = e.code
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html_bytes = self._failure_email_template.render(**template_vars).encode(
+ "utf-8"
+ )
+
+ return status_code, html_bytes
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8e5d78f6f7..bbff3c8d5b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
- self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorage(hs, stores)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ed8a9bffb1..79ec8f119d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -952,7 +952,7 @@ class DatabasePool:
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times.
@@ -981,7 +981,7 @@ class DatabasePool:
key_names: Iterable[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..306fc6947c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
}
async def get_users_whose_devices_changed(
- self, from_key: str, user_ids: Iterable[str]
+ self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
- from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def get_users_whose_signatures_changed(
- self, user_id: str, from_key: str
+ self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
A set of user IDs with updated signatures.
"""
- from_key = int(from_key)
+
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..9cd1403b38 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -213,7 +213,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
- results = []
+ results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@@ -631,7 +631,9 @@ class PersistEventsStore:
)
@classmethod
- def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ def _filter_events_and_contexts_for_duplicates(
+ cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +643,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
- new_events_and_contexts = OrderedDict()
+ new_events_and_contexts = (
+ OrderedDict()
+ ) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@@ -655,7 +659,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
- def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ def _update_room_depths_txn(
+ self,
+ txn,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ ):
"""Update min_depth for each room
Args:
@@ -664,7 +673,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
- depth_updates = {}
+ depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1436,7 +1445,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- events_by_room = {}
+ events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 86557d5512..1d76c761a6 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -17,6 +17,10 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
+ "media_repository_drop_index_wo_method"
+)
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -32,6 +36,59 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
where_clause="url_cache IS NOT NULL",
)
+ # The following the updates add the method to the unique constraint of
+ # the thumbnail databases. That fixes an issue, where thumbnails of the
+ # same resolution, but different methods could overwrite one another.
+ # This can happen with custom thumbnail configs or with dynamic thumbnailing.
+ self.db_pool.updates.register_background_index_update(
+ update_name="local_media_repository_thumbnails_method_idx",
+ index_name="local_media_repository_thumbn_media_id_width_height_method_key",
+ table="local_media_repository_thumbnails",
+ columns=[
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ update_name="remote_media_repository_thumbnails_method_idx",
+ index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
+ table="remote_media_cache_thumbnails",
+ columns=[
+ "media_origin",
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+ self._drop_media_index_without_method,
+ )
+
+ async def _drop_media_index_without_method(self, progress, batch_size):
+ def f(txn):
+ txn.execute(
+ "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+ txn.execute(
+ "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+
+ await self.db_pool.runInteraction("drop_media_indices_without_method", f)
+ await self.db_pool.updates._end_background_update(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+ )
+ return 1
+
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ea833829ae..d7a03cbf7d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# room_depth
# state_groups
# state_groups_state
+ # destination_rooms
# we will build a temporary table listing the events so that we don't
# have to keep shovelling the list back and forth across the
@@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# and finally, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
+ "destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 0de802a86b..9790a31998 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -13,11 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
from typing import List, Tuple, Union
+from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
@@ -540,6 +541,25 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
+ # ensure we have a push_rules_enable row
+ # enabledness defaults to true
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT DO NOTHING
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ new_enable_id = self._push_rules_enable_id_gen.get_next()
+ txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
@@ -552,6 +572,12 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ # we don't use simple_delete_one_txn because that would fail if the
+ # user did not have a push_rule_enable row.
+ self.db_pool.simple_delete_txn(
+ txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}
+ )
+
self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -570,10 +596,29 @@ class PushRuleStore(PushRulesWorkerStore):
event_stream_ordering,
)
- async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+ async def set_push_rule_enabled(
+ self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool
+ ) -> None:
+ """
+ Sets the `enabled` state of a push rule.
+
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ enabled: True if the rule is to be enabled, False if it is to be
+ disabled
+ is_default_rule: True if and only if this is a server-default rule.
+ This skips the check for existence (as only user-created rules
+ are always stored in the database `push_rules` table).
+
+ Raises:
+ NotFoundError if the rule does not exist.
+ """
with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
-
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
@@ -582,12 +627,47 @@ class PushRuleStore(PushRulesWorkerStore):
user_id,
rule_id,
enabled,
+ is_default_rule,
)
def _set_push_rule_enabled_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
+ is_default_rule,
):
new_id = self._push_rules_enable_id_gen.get_next()
+
+ if not is_default_rule:
+ # first check it exists; we need to lock for key share so that a
+ # transaction that deletes the push rule will conflict with this one.
+ # We also need a push_rule_enable row to exist for every push_rules
+ # row, otherwise it is possible to simultaneously delete a push rule
+ # (that has no _enable row) and enable it, resulting in a dangling
+ # _enable row. To solve this: we either need to use SERIALISABLE or
+ # ensure we always have a push_rule_enable row for every push_rule
+ # row. We chose the latter.
+ for_key_share = "FOR KEY SHARE"
+ if not isinstance(self.database_engine, PostgresEngine):
+ # For key share is not applicable/available on SQLite
+ for_key_share = ""
+ sql = (
+ """
+ SELECT 1 FROM push_rules
+ WHERE user_name = ? AND rule_id = ?
+ %s
+ """
+ % for_key_share
+ )
+ txn.execute(sql, (user_id, rule_id))
+ if txn.fetchone() is None:
+ # needed to set NOT_FOUND code.
+ raise NotFoundError("Push rule does not exist.")
+
self.db_pool.simple_upsert_txn(
txn,
"push_rules_enable",
@@ -606,8 +686,30 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_actions(
- self, user_id, rule_id, actions, is_default_rule
+ self,
+ user_id: str,
+ rule_id: str,
+ actions: List[Union[dict, str]],
+ is_default_rule: bool,
) -> None:
+ """
+ Sets the `actions` state of a push rule.
+
+ Will throw NotFoundError if the rule does not exist; the Code for this
+ is NOT_FOUND.
+
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ actions: A list of actions (each action being a dict or string),
+ e.g. ["notify", {"set_tweak": "highlight", "value": false}]
+ is_default_rule: True if and only if this is a server-default rule.
+ This skips the check for existence (as only user-created rules
+ are always stored in the database `push_rules` table).
+ """
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -629,12 +731,19 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.db_pool.simple_update_one_txn(
- txn,
- "push_rules",
- {"user_name": user_id, "rule_id": rule_id},
- {"actions": actions_json},
- )
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "push_rules",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
+ )
+ except StoreError as serr:
+ if serr.code == 404:
+ # this sets the NOT_FOUND error Code
+ raise NotFoundError("Push rule does not exist")
+ else:
+ raise
self._insert_push_rules_update_txn(
txn,
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..b64926e9c9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -0,0 +1,33 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is the postgres specific migration modifying the table with a background
+ * migration.
+ */
+
+-- add new index that includes method to local media
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('local_media_repository_thumbnails_method_idx', '{}');
+
+-- add new index that includes method to remote media
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+
+-- drop old index
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
new file mode 100644
index 0000000000..1d0c04b53a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
@@ -0,0 +1,44 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is a sqlite specific migration, since sqlite can't modify the unique
+ * constraint of a table without recreating it.
+ */
+
+CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO local_media_repository_thumbnails_new
+ SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length
+ FROM local_media_repository_thumbnails;
+
+DROP TABLE local_media_repository_thumbnails;
+
+ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails;
+
+CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id);
+
+
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO remote_media_cache_thumbnails_new
+ SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id
+ FROM remote_media_cache_thumbnails;
+
+DROP TABLE remote_media_cache_thumbnails;
+
+ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails;
diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
new file mode 100644
index 0000000000..847aebd85e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules.
+ We ignore rules that are server-default rules because they are not defined
+ in the `push_rules` table.
+**/
+
+DELETE FROM push_rules_enable WHERE
+ rule_id NOT LIKE 'global/%/.m.rule.%'
+ AND NOT EXISTS (
+ SELECT 1 FROM push_rules
+ WHERE push_rules.user_name = push_rules_enable.user_name
+ AND push_rules.rule_id = push_rules_enable.rule_id
+ );
diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
new file mode 100644
index 0000000000..ebfbed7925
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
@@ -0,0 +1,42 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This schema delta alters the schema to enable 'catching up' remote homeservers
+-- after there has been a connectivity problem for any reason.
+
+-- This stores, for each (destination, room) pair, the stream_ordering of the
+-- latest event for that destination.
+CREATE TABLE IF NOT EXISTS destination_rooms (
+ -- the destination in question.
+ destination TEXT NOT NULL REFERENCES destinations (destination),
+ -- the ID of the room in question
+ room_id TEXT NOT NULL REFERENCES rooms (room_id),
+ -- the stream_ordering of the event
+ stream_ordering BIGINT NOT NULL,
+ PRIMARY KEY (destination, room_id)
+ -- We don't declare a foreign key on stream_ordering here because that'd mean
+ -- we'd need to either maintain an index (expensive) or do a table scan of
+ -- destination_rooms whenever we delete an event (also potentially expensive).
+ -- In addition to that, a foreign key on stream_ordering would be redundant
+ -- as this row doesn't need to refer to a specific event; if the event gets
+ -- deleted then it doesn't affect the validity of the stream_ordering here.
+);
+
+-- This index is needed to make it so that a deletion of a room (in the rooms
+-- table) can be efficient, as otherwise a table scan would need to be performed
+-- to check that no destination_rooms rows point to the room to be deleted.
+-- Also: it makes it efficient to delete all the entries for a given room ID,
+-- such as when purging a room.
+CREATE INDEX IF NOT EXISTS destination_rooms_room_id
+ ON destination_rooms (room_id);
diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
new file mode 100644
index 0000000000..55f5d0f732
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky
+-- populate_stats_process_rooms_2 background job and restores the functionality under the
+-- original name.
+-- See https://github.com/matrix-org/synapse/issues/8238 for details
+
+DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms';
+UPDATE background_updates SET update_name = 'populate_stats_process_rooms'
+ WHERE update_name = 'populate_stats_process_rooms_2';
diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
new file mode 100644
index 0000000000..a67aa5e500
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This column tracks the stream_ordering of the event that was most recently
+-- successfully transmitted to the destination.
+-- A value of NULL means that we have not sent an event successfully yet
+-- (at least, not since the introduction of this column).
+ALTER TABLE destinations
+ ADD COLUMN last_successful_stream_ordering BIGINT;
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 55a250ef06..30840dbbaa 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore):
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
self.db_pool.updates.register_background_update_handler(
- "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
- )
- self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
@@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
async def _populate_stats_process_rooms(self, progress, batch_size):
- """
- This was a background update which regenerated statistics for rooms.
-
- It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
- job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure
- someone upgrading from <v1.0.0, this background task has been turned into a no-op
- so that the potentially expensive task is not run twice.
-
- Further context: https://github.com/matrix-org/synapse/pull/7977
- """
- await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms"
- )
- return 1
-
- async def _populate_stats_process_rooms_2(self, progress, batch_size):
- """
- This is a background update which regenerates statistics for rooms.
-
- It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
- reasoning.
- """
+ """This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore):
return [r for r, in txn]
rooms_to_work_on = await self.db_pool.runInteraction(
- "populate_stats_rooms_2_get_batch", _get_next_batch
+ "populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore):
progress["last_room_id"] = room_id
await self.db_pool.runInteraction(
- "_populate_stats_process_rooms_2",
+ "_populate_stats_process_rooms",
self.db_pool.updates._background_update_progress_txn,
- "populate_stats_process_rooms_2",
+ "populate_stats_process_rooms",
progress,
)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index be6df8a6d1..2e95518752 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
- from_token: Optional[Tuple[int, int]],
- to_token: Optional[Tuple[int, int]],
+ from_token: Optional[Tuple[Optional[int], int]],
+ to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
@@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Dict[str, Tuple[List[EventBase], str]]:
+ ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
- room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+ room_ids = self._events_stream_cache.get_entities_changed(
+ room_ids, from_key.stream
+ )
if not room_ids:
return {}
@@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
- self, room_ids: Collection[str], from_key: str
+ self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
-
- Args:
- room_ids
- from_key: The room_key portion of a StreamToken
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = from_key.stream
return {
room_id
for room_id in room_ids
@@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r.stream_ordering for r in rows)
+ key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: str, to_key: str
+ self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
if from_key == to_key:
return []
@@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[EventBase], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[_EventDictReturn], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -535,8 +531,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- end_token = RoomStreamToken.parse(end_token)
-
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
@@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
- async def get_stream_token_for_event(self, event_id: str) -> str:
+ async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A "s%d" stream token.
+ A stream token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
- return "s%d" % (stream_id,)
+ return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
@@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[_EventDictReturn], str]:
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -989,8 +983,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token,
- to_token=to_token,
+ from_token=from_token.as_tuple(),
+ to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
- return rows, str(next_token)
+ return rows, next_token
async def paginate_room_events(
self,
room_id: str,
- from_key: str,
- to_key: Optional[str] = None,
+ from_key: RoomStreamToken,
+ to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1083,10 +1077,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
- from_key = RoomStreamToken.parse(from_key)
- if to_key:
- to_key = RoomStreamToken.parse(to_key)
-
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 5b31aab700..c0a958252e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,13 +15,14 @@
import logging
from collections import namedtuple
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore):
allow_none=True,
)
- if result and result["retry_last_ts"] > 0:
+ # check we have a row and retry_last_ts is not null or zero
+ # (retry_last_ts can't be negative)
+ if result and result["retry_last_ts"]:
return result
else:
return None
@@ -273,3 +276,98 @@ class TransactionStore(SQLBaseStore):
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
+
+ async def store_destination_rooms_entries(
+ self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ ) -> None:
+ """
+ Updates or creates `destination_rooms` entries in batch for a single event.
+
+ Args:
+ destinations: list of destinations
+ room_id: the room_id of the event
+ stream_ordering: the stream_ordering of the event
+ """
+
+ return await self.db_pool.runInteraction(
+ "store_destination_rooms_entries",
+ self._store_destination_rooms_entries_txn,
+ destinations,
+ room_id,
+ stream_ordering,
+ )
+
+ def _store_destination_rooms_entries_txn(
+ self,
+ txn: LoggingTransaction,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
+ ) -> None:
+
+ # ensure we have a `destinations` row for this destination, as there is
+ # a foreign key constraint.
+ if isinstance(self.database_engine, PostgresEngine):
+ q = """
+ INSERT INTO destinations (destination)
+ VALUES (?)
+ ON CONFLICT DO NOTHING;
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ q = """
+ INSERT OR IGNORE INTO destinations (destination)
+ VALUES (?);
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ txn.execute_batch(q, ((destination,) for destination in destinations))
+
+ rows = [(destination, room_id) for destination in destinations]
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ "destination_rooms",
+ ["destination", "room_id"],
+ rows,
+ ["stream_ordering"],
+ [(stream_ordering,)] * len(rows),
+ )
+
+ async def get_destination_last_successful_stream_ordering(
+ self, destination: str
+ ) -> Optional[int]:
+ """
+ Gets the stream ordering of the PDU most-recently successfully sent
+ to the specified destination, or None if this information has not been
+ tracked yet.
+
+ Args:
+ destination: the destination to query
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "destinations",
+ {"destination": destination},
+ "last_successful_stream_ordering",
+ allow_none=True,
+ desc="get_last_successful_stream_ordering",
+ )
+
+ async def set_destination_last_successful_stream_ordering(
+ self, destination: str, last_successful_stream_ordering: int
+ ) -> None:
+ """
+ Marks that we have successfully sent the PDUs up to and including the
+ one specified.
+
+ Args:
+ destination: the destination we have successfully sent to
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ return await self.db_pool.simple_upsert(
+ "destinations",
+ keyvalues={"destination": destination},
+ values={"last_successful_stream_ordering": last_successful_stream_ordering},
+ desc="set_last_successful_stream_ordering",
+ )
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..d89f6ed128 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -185,6 +185,8 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+
+ assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
@@ -208,7 +210,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
- partitioned = {}
+ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -305,7 +307,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room = {}
+ events_by_room = (
+ {}
+ ) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@@ -436,7 +440,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
- latest_event_ids: List[str],
+ latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@@ -470,7 +474,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
- )
+ ) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ee60e2a718..a7f2dfb850 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
import os
import re
from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
import attr
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+def prepare_database(
+ db_conn: Connection,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str] = ["main", "state"],
+):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
Args:
db_conn:
database_engine:
- config (synapse.config.homeserver.HomeServerConfig|None):
+ config :
application config, or None if we are connecting to an existing
database which we expect to be configured already
- databases (list[str]): The name of the databases that will be used
+ databases: The name of the databases that will be used
with this physical database. Defaults to all databases.
"""
try:
cur = db_conn.cursor()
+ # sqlite does not automatically start transactions for DDL / SELECT statements,
+ # so we start one before running anything. This ensures that any upgrades
+ # are either applied completely, or not at all.
+ #
+ # (psycopg2 automatically starts a transaction as soon as we run any statements
+ # at all, so this is redundant but harmless there.)
+ cur.execute("BEGIN TRANSACTION")
+
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..2a66b3ad4e 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -224,6 +224,10 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # Set of local IDs that we've processed that are larger than the current
+ # position, due to there being smaller unpersisted IDs.
+ self._finished_ids = set() # type: Set[int]
+
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
@@ -348,17 +352,44 @@ class MultiWriterIdGenerator:
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
- current poistion if possible.
+ current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
+ self._finished_ids.add(next_id)
+
+ new_cur = None
+
+ if self._unfinished_ids:
+ # If there are unfinished IDs then the new position will be the
+ # largest finished ID less than the minimum unfinished ID.
+
+ finished = set()
+
+ min_unfinshed = min(self._unfinished_ids)
+ for s in self._finished_ids:
+ if s < min_unfinshed:
+ if new_cur is None or new_cur < s:
+ new_cur = s
+ else:
+ finished.add(s)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids = finished
+ else:
+ # There are no unfinished IDs so the new position is simply the
+ # largest finished one.
+ new_cur = max(self._finished_ids)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids.clear()
- # Figure out if its safe to advance the position by checking there
- # aren't any lower allocated IDs that are yet to finish.
- if all(c > next_id for c in self._unfinished_ids):
+ if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
- self._current_positions[self._instance_name] = max(curr, next_id)
+ self._current_positions[self._instance_name] = max(curr, new_cur)
self._add_persisted_position(next_id)
@@ -428,7 +459,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values())
+ min_curr = min(self._current_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index d97dc4d101..0bdf846edf 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -14,9 +14,13 @@
# limitations under the License.
import logging
+from typing import Optional
+
+import attr
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@@ -25,38 +29,22 @@ logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
-class SourcePaginationConfig:
-
- """A configuration object which stores pagination parameters for a
- specific event source."""
-
- def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
- self.from_key = from_key
- self.to_key = to_key
- self.direction = "f" if direction == "f" else "b"
- self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
-
- def __repr__(self):
- return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % (
- self.from_key,
- self.to_key,
- self.direction,
- self.limit,
- )
-
-
+@attr.s(slots=True)
class PaginationConfig:
-
"""A configuration object which stores pagination parameters."""
- def __init__(self, from_token=None, to_token=None, direction="f", limit=None):
- self.from_token = from_token
- self.to_token = to_token
- self.direction = "f" if direction == "f" else "b"
- self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
+ from_token = attr.ib(type=Optional[StreamToken])
+ to_token = attr.ib(type=Optional[StreamToken])
+ direction = attr.ib(type=str)
+ limit = attr.ib(type=Optional[int])
@classmethod
- def from_request(cls, request, raise_invalid_params=True, default_limit=None):
+ def from_request(
+ cls,
+ request: SynapseRequest,
+ raise_invalid_params: bool = True,
+ default_limit: Optional[int] = None,
+ ) -> "PaginationConfig":
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
from_tok = parse_string(request, "from")
@@ -78,8 +66,11 @@ class PaginationConfig:
limit = parse_integer(request, "limit", default=default_limit)
- if limit and limit < 0:
- raise SynapseError(400, "Limit must be 0 or above")
+ if limit:
+ if limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
+
+ limit = min(int(limit), MAX_LIMIT)
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
@@ -87,20 +78,10 @@ class PaginationConfig:
logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.")
- def __repr__(self):
+ def __repr__(self) -> str:
return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
self.from_token,
self.to_token,
self.direction,
self.limit,
)
-
- def get_source_config(self, source_name):
- keyname = "%s_key" % source_name
-
- return SourcePaginationConfig(
- from_key=getattr(self.from_token, keyname),
- to_key=getattr(self.to_token, keyname) if self.to_token else None,
- direction=self.direction,
- limit=self.limit,
- )
diff --git a/synapse/types.py b/synapse/types.py
index f7de48f148..dc09448bdc 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,7 @@ import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -362,22 +362,81 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-class StreamToken(
- namedtuple(
- "Token",
- (
- "room_key",
- "presence_key",
- "typing_key",
- "receipt_key",
- "account_data_key",
- "push_rules_key",
- "to_device_key",
- "device_list_key",
- "groups_key",
- ),
+@attr.s(frozen=True, slots=True)
+class RoomStreamToken:
+ """Tokens are positions between events. The token "s1" comes after event 1.
+
+ s0 s1
+ | |
+ [0] V [1] V [2]
+
+ Tokens can either be a point in the live event stream or a cursor going
+ through historic events.
+
+ When traversing the live event stream events are ordered by when they
+ arrived at the homeserver.
+
+ When traversing historic events the events are ordered by their depth in
+ the event graph "topological_ordering" and then by when they arrived at the
+ homeserver "stream_ordering".
+
+ Live tokens start with an "s" followed by the "stream_ordering" id of the
+ event it comes after. Historic tokens start with a "t" followed by the
+ "topological_ordering" id of the event it comes after, followed by "-",
+ followed by the "stream_ordering" id of the event it comes after.
+ """
+
+ topological = attr.ib(
+ type=Optional[int],
+ validator=attr.validators.optional(attr.validators.instance_of(int)),
)
-):
+ stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+
+ @classmethod
+ def parse(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ if string[0] == "t":
+ parts = string[1:].split("-", 1)
+ return cls(topological=int(parts[0]), stream=int(parts[1]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid token %r" % (string,))
+
+ @classmethod
+ def parse_stream_token(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid token %r" % (string,))
+
+ def as_tuple(self) -> Tuple[Optional[int], int]:
+ return (self.topological, self.stream)
+
+ def __str__(self) -> str:
+ if self.topological is not None:
+ return "t%d-%d" % (self.topological, self.stream)
+ else:
+ return "s%d" % (self.stream,)
+
+
+@attr.s(slots=True, frozen=True)
+class StreamToken:
+ room_key = attr.ib(
+ type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+ )
+ presence_key = attr.ib(type=int)
+ typing_key = attr.ib(type=int)
+ receipt_key = attr.ib(type=int)
+ account_data_key = attr.ib(type=int)
+ push_rules_key = attr.ib(type=int)
+ to_device_key = attr.ib(type=int)
+ device_list_key = attr.ib(type=int)
+ groups_key = attr.ib(type=int)
+
_SEPARATOR = "_"
START = None # type: StreamToken
@@ -385,24 +444,19 @@ class StreamToken(
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
- while len(keys) < len(cls._fields):
+ while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
- return cls(*keys)
+ return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
- return self._SEPARATOR.join([str(k) for k in self])
+ return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
@property
def room_stream_id(self):
- # TODO(markjh): Awful hack to work around hacks in the presence tests
- # which assume that the keys are integers.
- if type(self.room_key) is int:
- return self.room_key
- else:
- return int(self.room_key[1:].split("-")[-1])
+ return self.room_key.stream
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
@@ -418,7 +472,7 @@ class StreamToken(
or (int(other.groups_key) < int(self.groups_key))
)
- def copy_and_advance(self, key, new_value):
+ def copy_and_advance(self, key, new_value) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
@@ -434,64 +488,11 @@ class StreamToken(
else:
return self
- def copy_and_replace(self, key, new_value):
- return self._replace(**{key: new_value})
-
-
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
-
-
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
- """Tokens are positions between events. The token "s1" comes after event 1.
-
- s0 s1
- | |
- [0] V [1] V [2]
-
- Tokens can either be a point in the live event stream or a cursor going
- through historic events.
+ def copy_and_replace(self, key, new_value) -> "StreamToken":
+ return attr.evolve(self, **{key: new_value})
- When traversing the live event stream events are ordered by when they
- arrived at the homeserver.
- When traversing historic events the events are ordered by their depth in
- the event graph "topological_ordering" and then by when they arrived at the
- homeserver "stream_ordering".
-
- Live tokens start with an "s" followed by the "stream_ordering" id of the
- event it comes after. Historic tokens start with a "t" followed by the
- "topological_ordering" id of the event it comes after, followed by "-",
- followed by the "stream_ordering" id of the event it comes after.
- """
-
- __slots__ = [] # type: list
-
- @classmethod
- def parse(cls, string):
- try:
- if string[0] == "s":
- return cls(topological=None, stream=int(string[1:]))
- if string[0] == "t":
- parts = string[1:].split("-", 1)
- return cls(topological=int(parts[0]), stream=int(parts[1]))
- except Exception:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
-
- @classmethod
- def parse_stream_token(cls, string):
- try:
- if string[0] == "s":
- return cls(topological=None, stream=int(string[1:]))
- except Exception:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
-
- def __str__(self):
- if self.topological is not None:
- return "t%d-%d" % (self.topological, self.stream)
- else:
- return "s%d" % (self.stream,)
+StreamToken.START = StreamToken.from_string("s0_0")
class ThirdPartyInstanceID(
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index a13f11f8d8..60ecc498ab 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
import re
import attr
-from canonicaljson import json
from twisted.internet import defer, task
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index bb57e27beb..67ce9a5f39 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -17,13 +17,25 @@
import collections
import logging
from contextlib import contextmanager
-from typing import Dict, Sequence, Set, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Hashable,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ TypeVar,
+ Union,
+)
import attr
from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
from synapse.logging.context import (
@@ -54,7 +66,7 @@ class ObservableDeferred:
__slots__ = ["_deferred", "_observers", "_result"]
- def __init__(self, deferred, consumeErrors=False):
+ def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
@@ -111,25 +123,25 @@ class ObservableDeferred:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self):
+ def observers(self) -> List[defer.Deferred]:
return self._observers
- def has_called(self):
+ def has_called(self) -> bool:
return self._result is not None
- def has_succeeded(self):
+ def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True
- def get_result(self):
+ def get_result(self) -> Any:
return self._result[1]
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self._deferred, name)
- def __setattr__(self, name, value):
+ def __setattr__(self, name: str, value: Any) -> None:
setattr(self._deferred, name, value)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self),
self._result,
@@ -137,18 +149,20 @@ class ObservableDeferred:
)
-def concurrently_execute(func, args, limit):
- """Executes the function with each argument conncurrently while limiting
+def concurrently_execute(
+ func: Callable, args: Iterable[Any], limit: int
+) -> defer.Deferred:
+ """Executes the function with each argument concurrently while limiting
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred or coroutine.
- args (Iterable): List of arguments to pass to func, each invocation of func
+ func: Function to execute, should return a deferred or coroutine.
+ args: List of arguments to pass to func, each invocation of func
gets a single argument.
- limit (int): Maximum number of conccurent executions.
+ limit: Maximum number of conccurent executions.
Returns:
- deferred: Resolved when all function invocations have finished.
+ Deferred[list]: Resolved when all function invocations have finished.
"""
it = iter(args)
@@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit):
).addErrback(unwrapFirstError)
-def yieldable_gather_results(func, iter, *args, **kwargs):
+def yieldable_gather_results(
+ func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+) -> defer.Deferred:
"""Executes the function with each argument concurrently.
Args:
- func (func): Function to execute that returns a Deferred
- iter (iter): An iterable that yields items that get passed as the first
+ func: Function to execute that returns a Deferred
+ iter: An iterable that yields items that get passed as the first
argument to the function
*args: Arguments to be passed to each call to func
+ **kwargs: Keyword arguments to be passed to each call to func
Returns
Deferred[list]: Resolved when all functions have been invoked, or errors if
@@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
).addErrback(unwrapFirstError)
+@attr.s(slots=True)
+class _LinearizerEntry:
+ # The number of things executing.
+ count = attr.ib(type=int)
+ # Deferreds for the things blocked from executing.
+ deferreds = attr.ib(type=collections.OrderedDict)
+
+
class Linearizer:
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few things happen at a time on a given resource.
Example:
- with (yield limiter.queue("test_key")):
+ with await limiter.queue("test_key"):
# do some work.
"""
- def __init__(self, name=None, max_count=1, clock=None):
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ max_count: int = 1,
+ clock: Optional[Clock] = None,
+ ):
"""
Args:
- max_count(int): The maximum number of concurrent accesses
+ max_count: The maximum number of concurrent accesses
"""
if name is None:
- self.name = id(self)
+ self.name = id(self) # type: Union[str, int]
else:
self.name = name
@@ -216,15 +246,10 @@ class Linearizer:
self._clock = clock
self.max_count = max_count
- # key_to_defer is a map from the key to a 2 element list where
- # the first element is the number of things executing, and
- # the second element is an OrderedDict, where the keys are deferreds for the
- # things blocked from executing.
- self.key_to_defer = (
- {}
- ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+ # key_to_defer is a map from the key to a _LinearizerEntry.
+ self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry]
- def is_queued(self, key) -> bool:
+ def is_queued(self, key: Hashable) -> bool:
"""Checks whether there is a process queued up waiting
"""
entry = self.key_to_defer.get(key)
@@ -234,25 +259,27 @@ class Linearizer:
# There are waiting deferreds only in the OrderedDict of deferreds is
# non-empty.
- return bool(entry[1])
+ return bool(entry.deferreds)
- def queue(self, key):
+ def queue(self, key: Hashable) -> defer.Deferred:
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
# propagated inside inlineCallbacks until Twisted 18.7)
- entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+ entry = self.key_to_defer.setdefault(
+ key, _LinearizerEntry(0, collections.OrderedDict())
+ )
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
# When one of the things currently executing finishes it will callback
# this item so that it can continue executing.
- if entry[0] >= self.max_count:
+ if entry.count >= self.max_count:
res = self._await_lock(key)
else:
logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key
)
- entry[0] += 1
+ entry.count += 1
res = defer.succeed(None)
# once we successfully get the lock, we need to return a context manager which
@@ -267,15 +294,15 @@ class Linearizer:
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
- entry[0] -= 1
+ entry.count -= 1
- if entry[1]:
- (next_def, _) = entry[1].popitem(last=False)
+ if entry.deferreds:
+ (next_def, _) = entry.deferreds.popitem(last=False)
# we need to run the next thing in the sentinel context.
with PreserveLoggingContext():
next_def.callback(None)
- elif entry[0] == 0:
+ elif entry.count == 0:
# We were the last thing for this key: remove it from the
# map.
del self.key_to_defer[key]
@@ -283,7 +310,7 @@ class Linearizer:
res.addCallback(_ctx_manager)
return res
- def _await_lock(self, key):
+ def _await_lock(self, key: Hashable) -> defer.Deferred:
"""Helper for queue: adds a deferred to the queue
Assumes that we've already checked that we've reached the limit of the number
@@ -298,11 +325,11 @@ class Linearizer:
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
- entry[1][new_defer] = 1
+ entry.deferreds[new_defer] = 1
def cb(_r):
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
- entry[0] += 1
+ entry.count += 1
# if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can
@@ -331,7 +358,7 @@ class Linearizer:
)
# we just have to take ourselves back out of the queue.
- del entry[1][new_defer]
+ del entry.deferreds[new_defer]
return e
new_defer.addCallbacks(cb, eb)
@@ -419,14 +446,22 @@ class ReadWriteLock:
return _ctx_manager()
-def _cancelled_to_timed_out_error(value, timeout):
+R = TypeVar("R")
+
+
+def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
if isinstance(value, failure.Failure):
value.trap(CancelledError)
raise defer.TimeoutError(timeout, "Deferred")
return value
-def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+def timeout_deferred(
+ deferred: defer.Deferred,
+ timeout: float,
+ reactor: IReactorTime,
+ on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+) -> defer.Deferred:
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
deferred that wraps and times out the given deferred, correctly handling
@@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
Args:
- deferred (Deferred)
- timeout (float): Timeout in seconds
- reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
- on_timeout_cancel (callable): A callable which is called immediately
+ deferred: The Deferred to potentially timeout.
+ timeout: Timeout in seconds
+ reactor: The twisted reactor to use
+ on_timeout_cancel: A callable which is called immediately
after the deferred times out, and not if this deferred is
otherwise cancelled before the timeout.
@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
CancelledError Failure into a defer.TimeoutError.
Returns:
- Deferred
+ A new Deferred.
"""
new_d = defer.Deferred()
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index a750261e77..f73e95393c 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -16,8 +16,6 @@ import inspect
import logging
from twisted.internet import defer
-from twisted.internet.defer import Deferred, fail, succeed
-from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -29,11 +27,6 @@ def user_left_room(distributor, user, room_id):
distributor.fire("user_left_room", user=user, room_id=room_id)
-# XXX: this is no longer used. We should probably kill it.
-def user_joined_room(distributor, user, room_id):
- distributor.fire("user_joined_room", user=user, room_id=room_id)
-
-
class Distributor:
"""A central dispatch point for loosely-connected pieces of code to
register, observe, and fire signals.
@@ -81,28 +74,6 @@ class Distributor:
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
-def maybeAwaitableDeferred(f, *args, **kw):
- """
- Invoke a function that may or may not return a Deferred or an Awaitable.
-
- This is a modified version of twisted.internet.defer.maybeDeferred.
- """
- try:
- result = f(*args, **kw)
- except Exception:
- return fail(failure.Failure(captureVars=Deferred.debug))
-
- if isinstance(result, Deferred):
- return result
- # Handle the additional case of an awaitable being returned.
- elif inspect.isawaitable(result):
- return defer.ensureDeferred(result)
- elif isinstance(result, failure.Failure):
- return fail(result)
- else:
- return succeed(result)
-
-
class Signal:
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@@ -132,22 +103,17 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
- def do(observer):
- def eb(failure):
+ async def do(observer):
+ try:
+ result = observer(*args, **kwargs)
+ if inspect.isawaitable(result):
+ result = await result
+ return result
+ except Exception as e:
logger.warning(
- "%s signal observer %s failed: %r",
- self.name,
- observer,
- failure,
- exc_info=(
- failure.type,
- failure.value,
- failure.getTracebackObject(),
- ),
+ "%s signal observer %s failed: %r", self.name, observer, e,
)
- return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
-
deferreds = [run_in_background(do, o) for o in self.observers]
return make_deferred_yieldable(
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 0e445e01d7..bf094c9386 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import json
+import json
+
from frozendict import frozendict
@@ -66,5 +67,5 @@ def _handle_frozendict(obj):
# A JSONEncoder which is capable of encoding frozendicts without barfing.
# Additionally reduce the whitespace produced by JSON encoding.
frozendict_json_encoder = json.JSONEncoder(
- default=_handle_frozendict, separators=(",", ":"),
+ allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
)
diff --git a/sytest-blacklist b/sytest-blacklist
index 79b2d4402a..b563448016 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -36,3 +36,11 @@ Inbound federation of state requires event_id as a mandatory paramater
# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
Can upload self-signing keys
+
+# Blacklisted until MSC2753 is implemented
+Local users can peek into world_readable rooms by room ID
+We can't peek into rooms with shared history_visibility
+We can't peek into rooms with invited history_visibility
+We can't peek into rooms with joined history_visibility
+Local users can peek by room alias
+Peeked rooms only turn up in the sync for the device who peeked them
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 3d880c499d..1471cc1a28 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(None)
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
new file mode 100644
index 0000000000..6cdcc378f0
--- /dev/null
+++ b/tests/federation/test_federation_catch_up.py
@@ -0,0 +1,158 @@
+from mock import Mock
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.test_utils import event_injection, make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ state_handler = hs.get_state_handler()
+
+ # This mock is crucial for destination_rooms to be populated.
+ state_handler.get_current_hosts_in_room = Mock(
+ return_value=make_awaitable(["test", "host2"])
+ )
+
+ # whenever send_transaction is called, record the pdu data
+ self.pdus = []
+ self.failed_pdus = []
+ self.is_online = True
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
+ async def record_transaction(self, txn, json_cb):
+ if self.is_online:
+ data = json_cb()
+ self.pdus.extend(data["pdus"])
+ return {}
+ else:
+ data = json_cb()
+ self.failed_pdus.extend(data["pdus"])
+ raise IOError("Failed to connect because this is a test!")
+
+ def get_destination_room(self, room: str, destination: str = "host2") -> dict:
+ """
+ Gets the destination_rooms entry for a (destination, room_id) pair.
+
+ Args:
+ room: room ID
+ destination: what destination, default is "host2"
+
+ Returns:
+ Dictionary of { event_id: str, stream_ordering: int }
+ """
+ event_id, stream_ordering = self.get_success(
+ self.hs.get_datastore().db_pool.execute(
+ "test:get_destination_rooms",
+ None,
+ """
+ SELECT event_id, stream_ordering
+ FROM destination_rooms dr
+ JOIN events USING (stream_ordering)
+ WHERE dr.destination = ? AND dr.room_id = ?
+ """,
+ destination,
+ room,
+ )
+ )[0]
+ return {"event_id": event_id, "stream_ordering": stream_ordering}
+
+ @override_config({"send_federation": True})
+ def test_catch_up_destination_rooms_tracking(self):
+ """
+ Tests that we populate the `destination_rooms` table as needed.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"]
+
+ row_1 = self.get_destination_room(room)
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ row_2 = self.get_destination_room(room)
+
+ # check: events correctly registered in order
+ self.assertEqual(row_1["event_id"], event_id_1)
+ self.assertEqual(row_2["event_id"], event_id_2)
+ self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
+
+ @override_config({"send_federation": True})
+ def test_catch_up_last_successful_stream_ordering_tracking(self):
+ """
+ Tests that we populate the `destination_rooms` table as needed.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ # take the remote offline
+ self.is_online = False
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ self.helper.send(room, "wombats!", tok=u1_token)
+ self.pump()
+
+ lsso_1 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+
+ self.assertIsNone(
+ lsso_1,
+ "There should be no last successful stream ordering for an always-offline destination",
+ )
+
+ # bring the remote online
+ self.is_online = True
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ lsso_2 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+ row_2 = self.get_destination_room(room)
+
+ self.assertEqual(
+ self.pdus[0]["content"]["body"],
+ "rabbits!",
+ "Test fault: didn't receive the right PDU",
+ )
+ self.assertEqual(
+ row_2["event_id"],
+ event_id_2,
+ "Test fault: destination_rooms not updated correctly",
+ )
+ self.assertEqual(
+ lsso_2,
+ row_2["stream_ordering"],
+ "Send succeeded but not marked as last_successful_stream_ordering",
+ )
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 5f512ff8bf..917762e6b6 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
- mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+ mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c7efd3822d..97877c2e42 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
@@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
yield defer.ensureDeferred(
@@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index eddf5e2498..cb7c0ed51a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
+ return_value=make_awaitable(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index a609f148c0..312c0a0d41 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -54,7 +54,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms_2",
+ "update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
@@ -66,7 +66,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
@@ -219,10 +219,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
- {
- "update_name": "populate_stats_process_rooms_2",
- "progress_json": "{}",
- },
+ {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
@@ -231,7 +228,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
@@ -728,7 +725,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms_2",
+ "update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
@@ -740,7 +737,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7bf15c4ba9..f306a09bfa 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -80,6 +80,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_user_directory_stream_pos",
"get_current_state_deltas",
"get_device_updates_by_remote",
+ "get_room_max_stream_ordering",
]
)
@@ -116,7 +117,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
+ self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
(0, [])
)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 5d41443293..3e5a856584 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory:
self._cert_file = create_test_cert_file(sanlist)
def serverConnectionForTLS(self, tlsProtocol):
- ctx = SSL.Context(SSL.TLSv1_METHOD)
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8b4982ecb1..1d7edee5ba 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
- mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
@@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 160c630235..b8b7758d24 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
new file mode 100644
index 0000000000..081052f6a6
--- /dev/null
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ push_rule.register_servlets,
+ ]
+ hijack_auth = False
+
+ def test_enabled_on_creation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation, even though a rule with that
+ ruleId existed previously and was disabled.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_on_recreation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation, even if a rule with that
+ ruleId existed previously and was disabled.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_disable(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is disabled and enabled when we ask for it.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # re-enable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule enabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_404_when_get_non_existent(self):
+ """
+ Tests that `enabled` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_get(self):
+ """
+ Tests that `actions` gives you what you expect on a fresh rule.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
+ )
+
+ def test_actions_put(self):
+ """
+ Tests that PUT on actions updates the value you'd get from GET.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # change the rule actions
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["actions"], ["dont_notify"])
+
+ def test_actions_404_when_get_non_existent(self):
+ """
+ Tests that `actions` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..93f899d861 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import os
import re
from email.parser import Parser
+from typing import Optional
+from urllib.parse import urlencode
import pkg_resources
@@ -27,8 +28,10 @@ from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
+from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -69,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
"""Test basic password reset flow
@@ -250,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
+ # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
- self.render(request)
+ request.render(self.submit_token_resource)
+ self.pump()
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+ # password reset confirm button
+
+ # Send arguments as url-encoded form data, matching the template's behaviour
+ form_args = []
+ for key, value_list in request.args.items():
+ for value in value_list:
+ arg = (key, value)
+ form_args.append(arg)
+
+ # Confirm the password reset
+ request, channel = self.make_request(
+ "POST",
+ path,
+ content=urlencode(form_args).encode("utf8"),
+ shorthand=False,
+ content_is_form=True,
+ )
+ request.render(self.submit_token_resource)
+ self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -668,16 +696,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def _request_token(self, email, client_secret):
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link(self):
+ """Tests a valid next_link parameter value with no whitelist (good case)"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/good/site",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_exotic_protocol(self):
+ """Tests using a esoteric protocol as a next_link parameter value.
+ Someone may be hosting a client on IPFS etc.
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_file_uri(self):
+ """Tests next_link parameters cannot be file URI"""
+ # Attempt to use a next_link value that points to the local disk
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="file:///host/path",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+ def test_next_link_domain_whitelist(self):
+ """Tests next_link parameters must fit the whitelist if provided"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/some/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.org/some/also/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://bad.example.org/some/bad/page",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": []})
+ def test_empty_next_link_domain_whitelist(self):
+ """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+ disallowed
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/page",
+ expect_code=400,
+ )
+
+ def _request_token(
+ self,
+ email: str,
+ client_secret: str,
+ next_link: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> str:
+ """Request a validation token to add an email address to a user's account
+
+ Args:
+ email: The email address to validate
+ client_secret: A secret string
+ next_link: A link to redirect the user to after validation
+ expect_code: Expected return code of the call
+
+ Returns:
+ The ID of the new threepid validation session
+ """
+ body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+ if next_link:
+ body["next_link"] = next_link
+
request, channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ "POST", b"account/3pid/email/requestToken", body,
)
self.render(request)
- self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(expect_code, channel.code, channel.result)
- return channel.json_body["sid"]
+ return channel.json_body.get("sid")
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes])
expected_scaled = attr.ib(type=Optional[bytes])
+ expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
("test_image",),
[
- # smol png
+ # smoll png
(
_TestImage(
unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
None,
),
),
+ # an empty file
+ (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- self._test_thumbnail("crop", self.test_image.expected_cropped)
+ self._test_thumbnail(
+ "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ )
def test_thumbnail_scale(self):
- self._test_thumbnail("scale", self.test_image.expected_scaled)
+ self._test_thumbnail(
+ "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ )
- def _test_thumbnail(self, method, expected_body):
+ def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
request, channel = self.make_request(
"GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.pump()
- self.assertEqual(channel.code, 200)
- if expected_body is not None:
+ if expected_found:
+ self.assertEqual(channel.code, 200)
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
+ else:
+ # A 404 with a JSON body.
+ self.assertEqual(channel.code, 404)
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.json_body,
+ {
+ "errcode": "M_NOT_FOUND",
+ "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+ % method,
+ },
)
- else:
- # ensure that the result is at least some valid image
- Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/server.py b/tests/server.py
index 48e45c6c8b..61ec670155 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,6 @@
import json
import logging
-from io import BytesIO
+from io import SEEK_END, BytesIO
import attr
from zope.interface import implementer
@@ -135,6 +135,7 @@ def make_request(
request=SynapseRequest,
shorthand=True,
federation_auth_origin=None,
+ content_is_form=False,
):
"""
Make a web request using the given method and path, feed it the
@@ -150,6 +151,8 @@ def make_request(
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -181,6 +184,8 @@ def make_request(
req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
+ # Twisted expects to be at the end of the content when parsing the request.
+ req.content.seek(SEEK_END)
req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
@@ -195,7 +200,13 @@ def make_request(
)
if content:
- req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+ if content_is_form:
+ req.requestHeaders.addRawHeader(
+ b"Content-Type", b"application/x-www-form-urlencoded"
+ )
+ else:
+ # Assume the body is JSON
+ req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
req.requestReceived(method, path, b"1.1")
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 973338ea71..6382b19dc3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(1000)
+ return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
@@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_tags_for_room = Mock(
- side_effect=lambda user_id, room_id: make_awaitable({})
- )
+ self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(None)
+ return_value=make_awaitable(None)
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(1000)
- )
+ self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(1000)
+ return_value=make_awaitable(1000)
)
# Call the function multiple times to ensure we only send the notice once
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 370c247e16..755c70db31 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(lots_of_users)
+ return_value=make_awaitable(lots_of_users)
)
self.get_success(
self.store.insert_client_ip(
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index f0a8e32f1e..20636fc400 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -122,6 +122,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+ def test_out_of_order_finish(self):
+ """Test that IDs persisted out of order are correctly handled
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+ ctx3 = self.get_success(id_gen.get_next())
+ ctx4 = self.get_success(id_gen.get_next())
+
+ s1 = ctx1.__enter__()
+ s2 = ctx2.__enter__()
+ s3 = ctx3.__enter__()
+ s4 = ctx4.__enter__()
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+ self.assertEqual(s3, 10)
+ self.assertEqual(s4, 11)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx2.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx1.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ ctx4.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ ctx3.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 11})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
+
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
correctly.
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9870c74883..643072bbaf 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 508aeba078..a298cc0fd3 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,6 +17,7 @@
"""
Utilities for running the unit tests
"""
+from asyncio import Future
from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-async def make_awaitable(result: Any):
- """Create an awaitable that just returns a result."""
- return result
+def make_awaitable(result: Any) -> Awaitable[Any]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future = Future() # type: ignore
+ future.set_result(result)
+ return future
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index fb1ca90336..e93aa84405 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -71,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- await hs.get_storage().persistence.persist_event(event, context)
+ persistence = hs.get_storage().persistence
+ assert persistence is not None
+
+ await persistence.persist_event(event, context)
return event
diff --git a/tests/unittest.py b/tests/unittest.py
index 3cb55a7e96..128dd4e19c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -353,6 +353,7 @@ class HomeserverTestCase(TestCase):
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
+ content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -368,6 +369,8 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -384,6 +387,7 @@ class HomeserverTestCase(TestCase):
request,
shorthand,
federation_auth_origin,
+ content_is_form,
)
def render(self, request):
|