diff --git a/synapse/__init__.py b/synapse/__init__.py
index f70381bc71..f493cbd7d1 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.18.0"
+__version__ = "1.19.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d8190f92ab..7aab764360 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -213,6 +213,7 @@ class Auth(object):
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
+ shadow_banned = user_info["shadow_banned"]
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
@@ -252,7 +253,12 @@ class Auth(object):
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
- user, token_id, is_guest, device_id, app_service=app_service
+ user,
+ token_id,
+ is_guest,
+ shadow_banned,
+ device_id,
+ app_service=app_service,
)
except KeyError:
raise MissingClientTokenError()
@@ -297,6 +303,7 @@ class Auth(object):
dict that includes:
`user` (UserID)
`is_guest` (bool)
+ `shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
@@ -356,6 +363,7 @@ class Auth(object):
ret = {
"user": user,
"is_guest": True,
+ "shadow_banned": False,
"token_id": None,
# all guests get the same device id
"device_id": GUEST_DEVICE_ID,
@@ -365,6 +373,7 @@ class Auth(object):
ret = {
"user": user,
"is_guest": False,
+ "shadow_banned": False,
"token_id": None,
"device_id": None,
}
@@ -488,6 +497,7 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
+ "shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 7393d6cb74..a8937d2595 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -23,7 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
-from synapse.storage.presence import UserPresenceState
+from synapse.api.presence import UserPresenceState
from synapse.types import RoomID, UserID
FILTER_SCHEMA = {
diff --git a/synapse/storage/presence.py b/synapse/api/presence.py
index 18a462f0ee..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/api/presence.py
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd137853b1..1417487427 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,12 +18,16 @@
import argparse
import errno
import os
+import time
+import urllib.parse
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, List, MutableMapping, Optional
+from typing import Any, Callable, List, MutableMapping, Optional
import attr
+import jinja2
+import pkg_resources
import yaml
@@ -100,6 +104,11 @@ class Config(object):
def __init__(self, root_config=None):
self.root = root_config
+ # Get the path to the default Synapse template directory
+ self.default_template_dir = pkg_resources.resource_filename(
+ "synapse", "res/templates"
+ )
+
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
@@ -184,6 +193,95 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
+ def read_templates(
+ self, filenames: List[str], custom_template_directory: Optional[str] = None,
+ ) -> List[jinja2.Template]:
+ """Load a list of template files from disk using the given variables.
+
+ This function will attempt to load the given templates from the default Synapse
+ template directory. If `custom_template_directory` is supplied, that directory
+ is tried first.
+
+ Files read are treated as Jinja templates. These templates are not rendered yet.
+
+ Args:
+ filenames: A list of template filenames to read.
+
+ custom_template_directory: A directory to try to look for the templates
+ before using the default Synapse template directory instead.
+
+ Raises:
+ ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+ Returns:
+ A list of jinja2 templates.
+ """
+ templates = []
+ search_directories = [self.default_template_dir]
+
+ # The loader will first look in the custom template directory (if specified) for the
+ # given filename. If it doesn't find it, it will use the default template dir instead
+ if custom_template_directory:
+ # Check that the given template directory exists
+ if not self.path_exists(custom_template_directory):
+ raise ConfigError(
+ "Configured template directory does not exist: %s"
+ % (custom_template_directory,)
+ )
+
+ # Search the custom template directory as well
+ search_directories.insert(0, custom_template_directory)
+
+ loader = jinja2.FileSystemLoader(search_directories)
+ env = jinja2.Environment(loader=loader, autoescape=True)
+
+ # Update the environment with our custom filters
+ env.filters.update(
+ {
+ "format_ts": _format_ts_filter,
+ "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+ }
+ )
+
+ for filename in filenames:
+ # Load the template
+ template = env.get_template(filename)
+ templates.append(template)
+
+ return templates
+
+
+def _format_ts_filter(value: int, format: str):
+ return time.strftime(format, time.localtime(value / 1000))
+
+
+def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
+ """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+ Args:
+ public_baseurl: The public, accessible base URL of the homeserver
+ """
+
+ def mxc_to_http_filter(value, width, height, resize_method="crop"):
+ if value[0:6] != "mxc://":
+ return ""
+
+ server_and_media_id = value[6:]
+ fragment = None
+ if "#" in server_and_media_id:
+ server_and_media_id, fragment = server_and_media_id.split("#", 1)
+ fragment = "#" + fragment
+
+ params = {"width": width, "height": height, "method": resize_method}
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ public_baseurl,
+ server_and_media_id,
+ urllib.parse.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
+
class RootConfig(object):
"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index a63acbdc63..7a796996c0 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -23,7 +23,6 @@ from enum import Enum
from typing import Optional
import attr
-import pkg_resources
from ._base import Config, ConfigError
@@ -98,21 +97,18 @@ class EmailConfig(Config):
if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
+ # A user-configurable template directory
template_dir = email_config.get("template_dir")
- # we need an absolute path, because we change directory after starting (and
- # we don't yet know what auxiliary templates like mail.css we will need).
- # (Note that loading as package_resources with jinja.PackageLoader doesn't
- # work for the same reason.)
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates")
-
- self.email_template_dir = os.path.abspath(template_dir)
+ if isinstance(template_dir, str):
+ # We need an absolute path, because we change directory after starting (and
+ # we don't yet know what auxiliary templates like mail.css we will need).
+ template_dir = os.path.abspath(template_dir)
+ elif template_dir is not None:
+ # If template_dir is something other than a str or None, warn the user
+ raise ConfigError("Config option email.template_dir must be type str")
self.email_enable_notifs = email_config.get("enable_notifs", False)
- account_validity_config = config.get("account_validity") or {}
- account_validity_renewal_enabled = account_validity_config.get("renew_at")
-
self.threepid_behaviour_email = (
# Have Synapse handle the email sending if account_threepid_delegates.email
# is not defined
@@ -166,19 +162,6 @@ class EmailConfig(Config):
email_config.get("validation_token_lifetime", "1h")
)
- if (
- self.email_enable_notifs
- or account_validity_renewal_enabled
- or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
- ):
- # make sure we can import the required deps
- import bleach
- import jinja2
-
- # prevent unused warnings
- jinja2
- bleach
-
if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
missing = []
if not self.email_notif_from:
@@ -196,49 +179,49 @@ class EmailConfig(Config):
# These email templates have placeholders in them, and thus must be
# parsed using a templating engine during a request
- self.email_password_reset_template_html = email_config.get(
+ password_reset_template_html = email_config.get(
"password_reset_template_html", "password_reset.html"
)
- self.email_password_reset_template_text = email_config.get(
+ password_reset_template_text = email_config.get(
"password_reset_template_text", "password_reset.txt"
)
- self.email_registration_template_html = email_config.get(
+ registration_template_html = email_config.get(
"registration_template_html", "registration.html"
)
- self.email_registration_template_text = email_config.get(
+ registration_template_text = email_config.get(
"registration_template_text", "registration.txt"
)
- self.email_add_threepid_template_html = email_config.get(
+ add_threepid_template_html = email_config.get(
"add_threepid_template_html", "add_threepid.html"
)
- self.email_add_threepid_template_text = email_config.get(
+ add_threepid_template_text = email_config.get(
"add_threepid_template_text", "add_threepid.txt"
)
- self.email_password_reset_template_failure_html = email_config.get(
+ password_reset_template_failure_html = email_config.get(
"password_reset_template_failure_html", "password_reset_failure.html"
)
- self.email_registration_template_failure_html = email_config.get(
+ registration_template_failure_html = email_config.get(
"registration_template_failure_html", "registration_failure.html"
)
- self.email_add_threepid_template_failure_html = email_config.get(
+ add_threepid_template_failure_html = email_config.get(
"add_threepid_template_failure_html", "add_threepid_failure.html"
)
# These templates do not support any placeholder variables, so we
# will read them from disk once during setup
- email_password_reset_template_success_html = email_config.get(
+ password_reset_template_success_html = email_config.get(
"password_reset_template_success_html", "password_reset_success.html"
)
- email_registration_template_success_html = email_config.get(
+ registration_template_success_html = email_config.get(
"registration_template_success_html", "registration_success.html"
)
- email_add_threepid_template_success_html = email_config.get(
+ add_threepid_template_success_html = email_config.get(
"add_threepid_template_success_html", "add_threepid_success.html"
)
- # Check templates exist
- for f in [
+ # Read all templates from disk
+ (
self.email_password_reset_template_html,
self.email_password_reset_template_text,
self.email_registration_template_html,
@@ -248,32 +231,36 @@ class EmailConfig(Config):
self.email_password_reset_template_failure_html,
self.email_registration_template_failure_html,
self.email_add_threepid_template_failure_html,
- email_password_reset_template_success_html,
- email_registration_template_success_html,
- email_add_threepid_template_success_html,
- ]:
- p = os.path.join(self.email_template_dir, f)
- if not os.path.isfile(p):
- raise ConfigError("Unable to find template file %s" % (p,))
-
- # Retrieve content of web templates
- filepath = os.path.join(
- self.email_template_dir, email_password_reset_template_success_html
+ password_reset_template_success_html_template,
+ registration_template_success_html_template,
+ add_threepid_template_success_html_template,
+ ) = self.read_templates(
+ [
+ password_reset_template_html,
+ password_reset_template_text,
+ registration_template_html,
+ registration_template_text,
+ add_threepid_template_html,
+ add_threepid_template_text,
+ password_reset_template_failure_html,
+ registration_template_failure_html,
+ add_threepid_template_failure_html,
+ password_reset_template_success_html,
+ registration_template_success_html,
+ add_threepid_template_success_html,
+ ],
+ template_dir,
)
- self.email_password_reset_template_success_html = self.read_file(
- filepath, "email.password_reset_template_success_html"
- )
- filepath = os.path.join(
- self.email_template_dir, email_registration_template_success_html
- )
- self.email_registration_template_success_html_content = self.read_file(
- filepath, "email.registration_template_success_html"
+
+ # Render templates that do not contain any placeholders
+ self.email_password_reset_template_success_html_content = (
+ password_reset_template_success_html_template.render()
)
- filepath = os.path.join(
- self.email_template_dir, email_add_threepid_template_success_html
+ self.email_registration_template_success_html_content = (
+ registration_template_success_html_template.render()
)
- self.email_add_threepid_template_success_html_content = self.read_file(
- filepath, "email.add_threepid_template_success_html"
+ self.email_add_threepid_template_success_html_content = (
+ add_threepid_template_success_html_template.render()
)
if self.email_enable_notifs:
@@ -290,17 +277,19 @@ class EmailConfig(Config):
% (", ".join(missing),)
)
- self.email_notif_template_html = email_config.get(
+ notif_template_html = email_config.get(
"notif_template_html", "notif_mail.html"
)
- self.email_notif_template_text = email_config.get(
+ notif_template_text = email_config.get(
"notif_template_text", "notif_mail.txt"
)
- for f in self.email_notif_template_text, self.email_notif_template_html:
- p = os.path.join(self.email_template_dir, f)
- if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p,))
+ (
+ self.email_notif_template_html,
+ self.email_notif_template_text,
+ ) = self.read_templates(
+ [notif_template_html, notif_template_text], template_dir,
+ )
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
@@ -309,18 +298,20 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
- if account_validity_renewal_enabled:
- self.email_expiry_template_html = email_config.get(
+ if self.account_validity.renew_by_email_enabled:
+ expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
- self.email_expiry_template_text = email_config.get(
+ expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt"
)
- for f in self.email_expiry_template_text, self.email_expiry_template_html:
- p = os.path.join(self.email_template_dir, f)
- if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p,))
+ (
+ self.account_validity_template_html,
+ self.account_validity_template_text,
+ ) = self.read_templates(
+ [expiry_template_html, expiry_template_text], template_dir,
+ )
subjects_config = email_config.get("subjects", {})
subjects = {}
@@ -400,9 +391,7 @@ class EmailConfig(Config):
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # Do not uncomment this setting unless you want to customise the templates.
#
# Synapse will look for the following templates in this directory:
#
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 9277b5f342..036f8c0e90 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -18,8 +18,6 @@ import logging
from typing import Any, List
import attr
-import jinja2
-import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -171,15 +169,9 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
- template_dir = saml2_config.get("template_dir")
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
-
- loader = jinja2.FileSystemLoader(template_dir)
- # enable auto-escape here, to having to remember to escape manually in the
- # template
- env = jinja2.Environment(loader=loader, autoescape=True)
- self.saml2_error_html_template = env.get_template("saml_error.html")
+ self.saml2_error_html_template = self.read_templates(
+ ["saml_error.html"], saml2_config.get("template_dir")
+ )
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 9f15ed109e..ed66f3eba1 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -26,7 +26,6 @@ import yaml
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
-from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -508,8 +507,6 @@ class ServerConfig(Config):
)
)
- _check_resource_config(self.listeners)
-
self.cleanup_extremities_with_dummy_events = config.get(
"cleanup_extremities_with_dummy_events", True
)
@@ -1133,20 +1130,3 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
-
-
-def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
- resource_names = {
- res_name
- for listener in listeners
- if listener.http_options
- for res in listener.http_options.resources
- for res_name in res.names
- }
-
- for resource in resource_names:
- if resource == "consent":
- try:
- check_requirements("resources.consent")
- except DependencyException as e:
- raise ConfigError(e.message)
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 73b7296399..4427676167 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,11 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from typing import Any, Dict
-import pkg_resources
-
from ._base import Config
@@ -29,22 +26,32 @@ class SSOConfig(Config):
def read_config(self, config, **kwargs):
sso_config = config.get("sso") or {} # type: Dict[str, Any]
- # Pick a template directory in order of:
- # * The sso-specific template_dir
- # * /path/to/synapse/install/res/templates
+ # The sso-specific template_dir
template_dir = sso_config.get("template_dir")
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
- self.sso_template_dir = template_dir
- self.sso_account_deactivated_template = self.read_file(
- os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
- "sso_account_deactivated_template",
+ # Read templates from disk
+ (
+ self.sso_redirect_confirm_template,
+ self.sso_auth_confirm_template,
+ self.sso_error_template,
+ sso_account_deactivated_template,
+ sso_auth_success_template,
+ ) = self.read_templates(
+ [
+ "sso_redirect_confirm.html",
+ "sso_auth_confirm.html",
+ "sso_error.html",
+ "sso_account_deactivated.html",
+ "sso_auth_success.html",
+ ],
+ template_dir,
)
- self.sso_auth_success_template = self.read_file(
- os.path.join(self.sso_template_dir, "sso_auth_success.html"),
- "sso_auth_success_template",
+
+ # These templates have no placeholders, so render them here
+ self.sso_account_deactivated_template = (
+ sso_account_deactivated_template.render()
)
+ self.sso_auth_success_template = sso_auth_success_template.render()
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 2b0ab2dcbf..4d65d4aeea 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
from twisted.internet import defer
+from synapse.api.presence import UserPresenceState
from synapse.metrics import LaterGauge
-from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
from .units import Edu
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 86a611c49c..c6b67c2dd3 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
import synapse
import synapse.metrics
+from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
@@ -39,7 +40,6 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 2fb1782f33..374c5610b4 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -25,12 +25,12 @@ from synapse.api.errors import (
HttpResponseException,
RequestSendFailed,
)
+from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -352,6 +352,28 @@ class PerDestinationQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
+
+ if e.retry_interval > 60 * 60 * 1000:
+ # we won't retry for another hour!
+ # (this suggests a significant outage)
+ # We drop pending PDUs and EDUs because otherwise they will
+ # rack up indefinitely.
+ # Note that:
+ # - the EDUs that are being dropped here are those that we can
+ # afford to drop (specifically, only typing notifications,
+ # read receipts and presence updates are being dropped here)
+ # - Other EDUs such as to_device messages are queued with a
+ # different mechanism
+ # - this is all volatile state that would be lost if the
+ # federation sender restarted anyway
+
+ # dropping read receipts is a bit sad but should be solved
+ # through another mechanism, because this is all volatile!
+ self._pending_pdus = []
+ self._pending_edus = []
+ self._pending_edus_keyed = {}
+ self._pending_presence = {}
+ self._pending_rrs = {}
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 590135d19c..b865bf5b48 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
-try:
- from synapse.push.mailer import load_jinja2_templates
-except ImportError:
- load_jinja2_templates = None
-
logger = logging.getLogger(__name__)
@@ -47,9 +42,11 @@ class AccountValidityHandler(object):
if (
self._account_validity.enabled
and self._account_validity.renew_by_email_enabled
- and load_jinja2_templates
):
# Don't do email-specific configuration if renewal by email is disabled.
+ self._template_html = self.config.account_validity_template_html
+ self._template_text = self.config.account_validity_template_text
+
try:
app_name = self.hs.config.email_app_name
@@ -65,17 +62,6 @@ class AccountValidityHandler(object):
self._raw_from = email.utils.parseaddr(self._from_string)[1]
- self._template_html, self._template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_expiry_template_html,
- self.config.email_expiry_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
-
# Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c24e7bafe0..68d6870e40 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
-from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.threepids import canonicalise_email
@@ -132,18 +131,17 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
- self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
- )[0]
+ self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
- self._sso_auth_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_auth_confirm.html"],
- )[0]
+ self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
+
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 73e787f2f7..bd468611ae 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -893,9 +893,7 @@ class EventCreationHandler(object):
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ await self.store.remove_push_actions_from_staging(event.event_id)
raise
async def _validate_canonical_alias(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index fa5ee5de8f..87d28a7ae9 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -38,7 +38,6 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
@@ -123,9 +122,7 @@ class OidcHandler:
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_error.html"]
- )[0]
+ self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5387b3724f..24e1940ee5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -33,13 +33,13 @@ from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
-from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c94209ab3d..999bc6efb5 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
+ shadow_banned=False,
):
"""Registers a new client on the server.
@@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
+ shadow_banned (bool): Shadow-ban the created user.
Returns:
str: user_id
Raises:
@@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
+ shadow_banned=shadow_banned,
)
if self.hs.config.user_directory_search_all_users:
@@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler):
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
address=address,
+ shadow_banned=shadow_banned,
)
# Successfully registered
@@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler):
admin=False,
user_type=None,
address=None,
+ shadow_banned=False,
):
"""Register user in the datastore.
@@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
+ shadow_banned (bool): Whether to shadow-ban the user
Returns:
Awaitable
@@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
+ shadow_banned=shadow_banned,
)
else:
return self.store.register_user(
@@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
+ shadow_banned=shadow_banned,
)
async def register_device(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a8545255b1..442cca28e6 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -22,7 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
-from typing import Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -32,11 +32,14 @@ from synapse.api.constants import (
RoomEncryptionAlgorithms,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
+ JsonDict,
Requester,
RoomAlias,
RoomID,
@@ -53,6 +56,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -61,7 +67,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@@ -92,7 +98,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
- }
+ } # type: Dict[str, Dict[str, Any]]
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -215,6 +221,9 @@ class RoomCreationHandler(BaseHandler):
old_room_state = await tombstone_context.get_current_state_ids()
+ # We know the tombstone event isn't an outlier so it has current state.
+ assert old_room_state is not None
+
# update any aliases
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
@@ -528,17 +537,21 @@ class RoomCreationHandler(BaseHandler):
logger.error("Unable to send updated alias events in new room: %s", e)
async def create_room(
- self, requester, config, ratelimit=True, creator_join_profile=None
+ self,
+ requester: Requester,
+ config: JsonDict,
+ ratelimit: bool = True,
+ creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
- requester (synapse.types.Requester):
+ requester:
The user who requested the room creation.
- config (dict) : A dict of configuration options.
- ratelimit (bool): set to False to disable the rate limiter
+ config : A dict of configuration options.
+ ratelimit: set to False to disable the rate limiter
- creator_join_profile (dict|None):
+ creator_join_profile:
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
@@ -601,6 +614,7 @@ class RoomCreationHandler(BaseHandler):
Codes.UNSUPPORTED_ROOM_VERSION,
)
+ room_alias = None
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -611,8 +625,6 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
- else:
- room_alias = None
invite_list = config.get("invite", [])
for i in invite_list:
@@ -771,23 +783,30 @@ class RoomCreationHandler(BaseHandler):
async def _send_events_for_new_room(
self,
- creator, # A Requester object.
- room_id,
- preset_config,
- invite_list,
- initial_state,
- creation_content,
- room_alias=None,
- power_level_content_override=None, # Doesn't apply when initial state has power level state event content
- creator_join_profile=None,
+ creator: Requester,
+ room_id: str,
+ preset_config: str,
+ invite_list: List[str],
+ initial_state: StateMap,
+ creation_content: JsonDict,
+ room_alias: Optional[RoomAlias] = None,
+ power_level_content_override: Optional[JsonDict] = None,
+ creator_join_profile: Optional[JsonDict] = None,
) -> int:
"""Sends the initial events into a new room.
+ `power_level_content_override` doesn't apply when initial state has
+ power level state event content.
+
Returns:
The stream_id of the last event persisted.
"""
- def create(etype, content, **kwargs):
+ creator_id = creator.user.to_string()
+
+ event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
+
+ def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -795,7 +814,7 @@ class RoomCreationHandler(BaseHandler):
return e
- async def send(etype, content, **kwargs) -> int:
+ async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
(
@@ -808,10 +827,6 @@ class RoomCreationHandler(BaseHandler):
config = self._presets_dict[preset_config]
- creator_id = creator.user.to_string()
-
- event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
-
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -852,7 +867,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
- }
+ } # type: JsonDict
if config["original_invitees_have_ops"]:
for invitee in invite_list:
@@ -906,7 +921,7 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
- self, creator_id: str, is_public: str, room_version: RoomVersion,
+ self, creator_id: str, is_public: bool, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -930,23 +945,30 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(
+ self,
+ user: UserID,
+ room_id: str,
+ event_id: str,
+ limit: int,
+ event_filter: Optional[Filter],
+ ) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user (UserID)
- room_id (str)
- event_id (str)
- limit (int): The maximum number of events to return in total
+ user
+ room_id
+ event_id
+ limit: The maximum number of events to return in total
(excluding state).
- event_filter (Filter|None): the filter to apply to the events returned
+ event_filter: the filter to apply to the events returned
(excluding the target event_id)
Returns:
@@ -1033,12 +1055,18 @@ class RoomContextHandler(object):
class RoomEventSource(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def get_new_events(
- self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
- ):
+ self,
+ user: UserID,
+ from_key: str,
+ limit: int,
+ room_ids: List[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
+ ) -> Tuple[List[EventBase], str]:
# We just ignore the key for now.
to_key = self.get_current_key()
@@ -1096,7 +1124,7 @@ class RoomShutdownHandler(object):
)
DEFAULT_ROOM_NAME = "Content Violation Notification"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.room_member_handler = hs.get_room_member_handler()
self._room_creation_handler = hs.get_room_creation_handler()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index af117fddf9..c38e037281 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -16,8 +16,7 @@
import email.mime.multipart
import email.utils
import logging
-import time
-import urllib
+import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
@@ -640,72 +639,3 @@ def string_ordinal_total(s):
for c in s:
tot += ord(c)
return tot
-
-
-def format_ts_filter(value, format):
- return time.strftime(format, time.localtime(value / 1000))
-
-
-def load_jinja2_templates(
- template_dir,
- template_filenames,
- apply_format_ts_filter=False,
- apply_mxc_to_http_filter=False,
- public_baseurl=None,
-):
- """Loads and returns one or more jinja2 templates and applies optional filters
-
- Args:
- template_dir (str): The directory where templates are stored
- template_filenames (list[str]): A list of template filenames
- apply_format_ts_filter (bool): Whether to apply a template filter that formats
- timestamps
- apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
- mxc urls to http urls
- public_baseurl (str|None): The public baseurl of the server. Required for
- apply_mxc_to_http_filter to be enabled
-
- Returns:
- A list of jinja2 templates corresponding to the given list of filenames,
- with order preserved
- """
- logger.info(
- "loading email templates %s from '%s'", template_filenames, template_dir
- )
- loader = jinja2.FileSystemLoader(template_dir)
- env = jinja2.Environment(loader=loader)
-
- if apply_format_ts_filter:
- env.filters["format_ts"] = format_ts_filter
-
- if apply_mxc_to_http_filter and public_baseurl:
- env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
-
- templates = []
- for template_filename in template_filenames:
- template = env.get_template(template_filename)
- templates.append(template)
-
- return templates
-
-
-def _create_mxc_to_http_filter(public_baseurl):
- def mxc_to_http_filter(value, width, height, resize_method="crop"):
- if value[0:6] != "mxc://":
- return ""
-
- serverAndMediaId = value[6:]
- fragment = None
- if "#" in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
- fragment = "#" + fragment
-
- params = {"width": width, "height": height, "method": resize_method}
- return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- public_baseurl,
- serverAndMediaId,
- urllib.parse.urlencode(params),
- fragment or "",
- )
-
- return mxc_to_http_filter
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 8ad0bf5936..f626797133 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -15,22 +15,13 @@
import logging
+from synapse.push.emailpusher import EmailPusher
+from synapse.push.mailer import Mailer
+
from .httppusher import HttpPusher
logger = logging.getLogger(__name__)
-# We try importing this if we can (it will fail if we don't
-# have the optional email dependencies installed). We don't
-# yet have the config to know if we need the email pusher,
-# but importing this after daemonizing seems to fail
-# (even though a simple test of importing from a daemonized
-# process works fine)
-try:
- from synapse.push.emailpusher import EmailPusher
- from synapse.push.mailer import Mailer, load_jinja2_templates
-except Exception:
- pass
-
class PusherFactory(object):
def __init__(self, hs):
@@ -43,16 +34,8 @@ class PusherFactory(object):
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
- self.notif_template_html, self.notif_template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_notif_template_html,
- self.config.email_notif_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
+ self._notif_template_html = hs.config.email_notif_template_html
+ self._notif_template_text = hs.config.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher
@@ -73,8 +56,8 @@ class PusherFactory(object):
mailer = Mailer(
hs=self.hs,
app_name=app_name,
- template_html=self.notif_template_html,
- template_text=self.notif_template_text,
+ template_html=self._notif_template_html,
+ template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index e5f22fb858..3250d41dde 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
- # ConsentResource uses select_autoescape, which arrived in jinja 2.9
- "resources.consent": ["Jinja2>=2.9"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index ce9420aa69..a02b27474d 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
+ shadow_banned,
):
"""
Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
+ shadow_banned (bool): Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
+ "shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
+ shadow_banned=content["shadow_banned"],
)
return 200, {}
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 00831879f3..e2df638cc5 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.api.errors import (
NotFoundError,
StoreError,
@@ -163,7 +162,7 @@ class PushRuleRestServlet(RestServlet):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
- def set_rule_attr(self, user_id, spec, val):
+ async def set_rule_attr(self, user_id, spec, val):
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -173,7 +172,9 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+ return await self.store.set_push_rule_enabled(
+ user_id, namespaced_rule_id, val
+ )
elif spec["attr"] == "actions":
actions = val.get("actions")
_check_actions(actions)
@@ -188,7 +189,7 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- return self.store.set_push_rule_actions(
+ return await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index fead85074b..203e76b9f2 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -32,7 +32,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_password_reset_template_html,
- self.config.email_password_reset_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_password_reset_template_html,
+ template_text=self.config.email_password_reset_template_text,
)
async def on_POST(self, request):
@@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_password_reset_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_password_reset_template_failure_html
)
async def on_GET(self, request, medium):
@@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_password_reset_template_success_html
+ html = self.config.email_password_reset_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
@@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_add_threepid_template_html,
- self.config.email_add_threepid_template_text,
- ],
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_add_threepid_template_html,
+ template_text=self.config.email_add_threepid_template_text,
)
async def on_POST(self, request):
@@ -578,9 +559,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_add_threepid_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_add_threepid_template_failure_html
)
async def on_GET(self, request):
@@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index f808175698..7290fd0756 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -44,7 +44,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- from synapse.push.mailer import Mailer, load_jinja2_templates
-
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_registration_template_html,
- self.config.email_registration_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_registration_template_html,
+ template_text=self.config.email_registration_template_text,
)
async def on_POST(self, request):
@@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
- )
-
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_registration_template_failure_html
)
async def on_GET(self, request, medium):
@@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
respond_with_html(request, status_code, html)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index f43463df53..90a1f9e8b1 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -18,8 +18,6 @@ from typing import Optional
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
update_name (str): Name of update
"""
- @defer.inlineCallbacks
- def noop_update(progress, batch_size):
- yield self._end_background_update(update_name)
+ async def noop_update(progress, batch_size):
+ await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
else:
runner = create_index_sqlite
- @defer.inlineCallbacks
- def updater(progress, batch_size):
+ async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.db_pool.runWithConnection(runner)
- yield self._end_background_update(update_name)
+ await self.db_pool.runWithConnection(runner)
+ await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4ada6f5563..8a9e06efcf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -332,8 +332,7 @@ class DatabasePool(object):
"""
return self._db_pool.running
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
@@ -342,7 +341,7 @@ class DatabasePool(object):
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = yield self.simple_select_list(
+ updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -614,8 +613,7 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- @defer.inlineCallbacks
- def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -631,7 +629,7 @@ class DatabasePool(object):
`or_ignore` is True
"""
try:
- yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -684,8 +682,7 @@ class DatabasePool(object):
txn.executemany(sql, vals)
- @defer.inlineCallbacks
- def simple_upsert(
+ async def simple_upsert(
self,
table,
keyvalues,
@@ -714,14 +711,14 @@ class DatabasePool(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
+ None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
- result = yield self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@@ -730,7 +727,6 @@ class DatabasePool(object):
insertion_values,
lock=lock,
)
- return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -1121,8 +1117,7 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
- @defer.inlineCallbacks
- def simple_select_many_batch(
+ async def simple_select_many_batch(
self,
table,
column,
@@ -1156,7 +1151,7 @@ class DatabasePool(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
- rows = yield self.runInteraction(
+ rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..02568a2391 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
- A Deferred which resolves when the state was set successfully.
+ An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..9a786e2929 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
- inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..431bd76693 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- def get_oldest_events_in_room(self, room_id):
- return self.db_pool.runInteraction(
- "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
- )
-
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
@@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
- def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.db_pool.simple_select_onecol_txn(
- txn,
- table="event_backward_extremities",
- keyvalues={"room_id": room_id},
- retcol="event_id",
- )
-
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..e8834b2162 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
- @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
- def get_unread_event_push_actions_by_room_for_user(
+ @cached(num_args=3, tree=True, max_entries=5000)
+ async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..b90e6de2d5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @defer.inlineCallbacks
- def _persist_events_and_state_updates(
+ async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- ):
+ ) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -136,7 +133,7 @@ class PersistEventsStore:
backfilled
Returns:
- Deferred: resolves when the events have been persisted
+ Resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -168,7 +165,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
- @defer.inlineCallbacks
- def _get_events_which_are_prevs(self, event_ids):
+ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
- event_ids (Iterable[str]): event ids to filter
+ event_ids: event ids to filter
Returns:
- Deferred[List[str]]: filtered event ids
+ Filtered event ids
"""
results = []
@@ -240,14 +236,13 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
- @defer.inlineCallbacks
- def _get_prevs_before_rejected(self, event_ids):
+ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
- event_ids (Iterable[str]): Events to find prev events for. Note
- that these must have already been persisted.
+ event_ids: Events to find prev events for. Note that these must have
+ already been persisted.
Returns:
- Deferred[set[str]]
+ The previous events.
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
- @defer.inlineCallbacks
- def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
- num_handled = yield self.db_pool.runInteraction(
+ num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
- @defer.inlineCallbacks
- def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- count = yield self.db_pool.runInteraction(
+ count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
- yield self.db_pool.updates._end_background_update("redactions_received_ts")
+ await self.db_pool.updates._end_background_update("redactions_received_ts")
return count
- @defer.inlineCallbacks
- def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
- yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+ await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
- @defer.inlineCallbacks
- def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows
- num_rows = yield self.db_pool.runInteraction(
+ num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
- yield self.db_pool.updates._end_background_update("event_store_labels")
+ await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..8c63a0dc4d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
-
- return self.db_pool.runInteraction(
- "get_approximate_received_ts", _get_approximate_received_ts_txn
- )
-
@defer.inlineCallbacks
def get_event(
self,
@@ -883,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.db_pool.simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
+ rows = yield defer.ensureDeferred(
+ self.db_pool.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
)
return {r["event_id"] for r in rows}
@@ -923,36 +889,6 @@ class EventsWorkerStore(SQLBaseStore):
)
return results
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.db_pool.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn,
- room_id,
- )
-
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -1222,97 +1158,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
- else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
-
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
- )
-
- return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1357,14 +1202,3 @@ class EventsWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
-
-
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..4e3ec02d14 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,8 +15,8 @@
from typing import List, Tuple
+from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
- def get_presence_for_users(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_presence_for_users(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..c2289a9557 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -32,7 +32,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_for_user(self, user_id):
+ rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -133,17 +133,15 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+ enabled_map = await self.get_push_rules_enabled_for_user(user_id)
use_new_defaults = user_id in self._users_new_default_push_rules
- rules = _load_rules(rows, enabled_map, use_new_defaults)
+ return _load_rules(rows, enabled_map, use_new_defaults)
- return rules
-
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_enabled_for_user(self, user_id):
+ results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -170,18 +168,15 @@ class PushRulesWorkerStore(
)
@cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
- def bulk_get_push_rules(self, user_ids):
+ async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+ enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
return results
- @defer.inlineCallbacks
- def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
+ new_room_id: ID of the new room.
+ user_id : ID of user the push rule belongs to.
+ rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
- yield self.add_push_rule(
+ await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
- @defer.inlineCallbacks
- def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
+ user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
- yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
- inlineCallbacks=True,
)
- def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
+ async def add_push_rule(
self,
user_id,
rule_id,
@@ -342,13 +335,13 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
- ):
+ ) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -362,7 +355,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -546,16 +539,15 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
+ async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
+ user_id: The matrix ID of the push rule owner
+ rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -569,18 +561,17 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
+ async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -611,8 +602,9 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ async def set_push_rule_actions(
+ self, user_id, rule_id, actions, is_default_rule
+ ) -> None:
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -653,7 +645,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..1126fd0751 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
- dataJson = r["data"]
+ data_json = r["data"]
try:
- r["data"] = db_to_json(dataJson)
+ r["data"] = db_to_json(data_json)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
- dataJson,
+ data_json,
e.args[0],
)
continue
yield r
- @defer.inlineCallbacks
- def user_has_pusher(self, user_id):
- ret = yield self.db_pool.simple_select_one_onecol(
+ async def user_has_pusher(self, user_id):
+ ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
- @defer.inlineCallbacks
- def get_pushers_by(self, keyvalues):
- ret = yield self.db_pool.simple_select_list(
+ async def get_pushers_by(self, keyvalues):
+ ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- @defer.inlineCallbacks
- def get_all_pushers(self):
+ async def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
- return rows
+ return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
- @cachedInlineCallbacks(num_args=1, max_entries=15000)
- def get_if_user_has_pusher(self, user_id):
+ @cached(num_args=1, max_entries=15000)
+ async def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- def get_if_users_have_pushers(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_if_users_have_pushers(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
return result
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
+ async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self.db_pool.simple_update_one(
+ ) -> None:
+ await self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
+ async def update_pusher_last_stream_ordering_and_success(
+ self,
+ app_id: str,
+ pushkey: str,
+ user_id: str,
+ last_stream_ordering: int,
+ last_success: int,
+ ) -> bool:
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
- app_id (str)
- pushkey (str)
- last_stream_ordering (int)
- last_success (int)
+ app_id
+ pushkey
+ user_id
+ last_stream_ordering
+ last_success
Returns:
- Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.db_pool.simple_update(
+ updated = await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db_pool.simple_update(
+ async def update_pusher_failing_since(
+ self, app_id, pushkey, user_id, failing_since
+ ) -> None:
+ await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db_pool.simple_select_list(
+ async def get_throttle_params_by_room(self, pusher_id):
+ res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
+ async def set_throttle_params(self, pusher_id, room_id, params) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_pusher(
+ async def add_pusher(
self,
user_id,
access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
- ):
+ ) -> None:
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
- @defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+ async def delete_pusher_by_app_id_pushkey_user_id(
+ self, app_id, pushkey, user_id
+ ) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
@@ -351,6 +345,6 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..19ad1c056f 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
- @cachedInlineCallbacks()
- def get_users_with_read_receipts_in_room(self, room_id):
- receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ @cached()
+ async def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedInlineCallbacks(num_args=2)
- def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db_pool.simple_select_list(
+ @cached(num_args=2)
+ async def get_receipts_for_user(self, user_id, receipt_type):
+ rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- @defer.inlineCallbacks
- def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db_pool.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def get_linearized_receipts_for_rooms(
+ self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_ids (list): List of room_ids.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+ room_id: List of room_ids.
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- list: A list of receipts.
+ A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
- room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
- results = yield self._get_linearized_receipts_for_rooms(
+ results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
- def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ async def get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
- room_ids (str): The room id.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+ room_ids: The room id.
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- Deferred[list]: A list of receipts.
+ A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
- defer.succeed([])
+ return []
- return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+ return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cachedInlineCallbacks(num_args=3, tree=True)
- def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ @cached(num_args=3, tree=True)
+ async def _get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+ rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
- inlineCallbacks=True,
)
- def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db_pool.runInteraction(
+ txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
- self, room_id, receipt_type, user_id
+ self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -472,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
- @defer.inlineCallbacks
- def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ async def insert_receipt(
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: dict,
+ ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
- return
+ return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -507,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db_pool.runInteraction(
+ linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.db_pool.runInteraction(
+ event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -535,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
- yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+ await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..068ad22b30 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
import logging
import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -304,7 +302,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str)
Returns:
- Deferred
+ Awaitable
"""
# We need to use an upsert, in case they user had already bound the
# threepid
@@ -952,6 +950,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname=None,
admin=False,
user_type=None,
+ shadow_banned=False,
):
"""Attempts to register an account.
@@ -968,6 +967,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
+ shadow_banned (bool): Whether the user is shadow-banned,
+ i.e. they may be told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
@@ -986,6 +987,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
)
def _register_user(
@@ -999,6 +1001,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1028,6 +1031,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
else:
@@ -1042,6 +1046,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
@@ -1077,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
- ) -> Deferred:
+ ) -> Awaitable:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1345,43 +1350,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"validate_threepid_session_txn", validate_threepid_session_txn
)
- def upsert_threepid_validation_session(
- self,
- medium,
- address,
- client_secret,
- send_attempt,
- session_id,
- validated_at=None,
- ):
- """Upsert a threepid validation session
- Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- session_id (str): The id of this validation session
- validated_at (int|None): The unix timestamp in milliseconds of
- when the session was marked as valid
- """
- insertion_values = {
- "medium": medium,
- "address": address,
- "client_secret": client_secret,
- }
-
- if validated_at:
- insertion_values["validated_at"] = validated_at
-
- return self.db_pool.simple_upsert(
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- values={"last_send_attempt": send_attempt},
- insertion_values=insertion_values,
- desc="upsert_threepid_validation_session",
- )
-
def start_or_continue_validation_session(
self,
medium,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..aef08c7e12 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -35,10 +35,6 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..161edbeccb 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
- @defer.inlineCallbacks
- def _count_known_servers(self):
+ async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+ count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
- list_name="event_ids",
- inlineCallbacks=True,
+ cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
- def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
- def get_membership_from_event_ids(
+ async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
- return self.db_pool.simple_select_many_batch(
+ return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..991233a9bc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
- inlineCallbacks=True,
)
- def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..8ccfb8fc46 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,15 +39,17 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-from typing import Optional
+from typing import Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine
-):
+ direction: str,
+ column_names: Tuple[str, str],
+ from_token: Optional[Tuple[int, int]],
+ to_token: Optional[Tuple[int, int]],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -90,21 +96,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction (str): Whether we're paginating backwards("b") or
- forwards ("f").
- column_names (tuple[str, str]): The column names to bound. Must *not*
- be user defined as these get inserted directly into the SQL
- statement without escapes.
- from_token (tuple[int, int]|None): The start point for the pagination.
- This is an exclusive minimum bound if direction is "f", and an
- inclusive maximum bound if direction is "b".
- to_token (tuple[int, int]|None): The endpoint point for the pagination.
- This is an inclusive maximum bound if direction is "f", and an
- exclusive minimum bound if direction is "b".
+ direction: Whether we're paginating backwards("b") or forwards ("f").
+ column_names: The column names to bound. Must *not* be user defined as
+ these get inserted directly into the SQL statement without escapes.
+ from_token: The start point for the pagination. This is an exclusive
+ minimum bound if direction is "f", and an inclusive maximum bound if
+ direction is "b".
+ to_token: The endpoint point for the pagination. This is an inclusive
+ maximum bound if direction is "f", and an exclusive minimum bound if
+ direction is "b".
engine: The database engine to generate the clauses for
Returns:
- str: The sql expression
+ The sql expression
"""
assert direction in ("b", "f")
@@ -132,7 +136,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+ bound: str,
+ column_names: Tuple[str, str],
+ values: Tuple[Optional[int], int],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually.
Args:
- bound (str): The comparison operator to use. One of ">", "<", ">=",
+ bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
- names (tuple[str, str]): The column names. Must *not* be user defined
+ names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
- values (tuple[int|None, int]): The values to bound the columns by. If
+ values: The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
- str
+ The SQL statement
"""
assert bound in (">", "<", ">=", "<=")
@@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
)
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self):
raise NotImplementedError()
- @defer.inlineCallbacks
- def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order="DESC"
- ):
+ async def get_room_events_stream_for_rooms(
+ self,
+ room_ids: Iterable[str],
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_ids
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[dict[str,tuple[list[FrozenEvent], str]]]
- A map from room id to a tuple containing:
- - list of recent events in the room
- - stream ordering key for the start of the chunk of events returned.
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
- room_ids = yield self._events_stream_cache.get_entities_changed(
- room_ids, from_id
- )
+ room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
if not room_ids:
return {}
@@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
- res = yield make_deferred_yieldable(
+ res = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_key)
}
- @defer.inlineCallbacks
- def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order="DESC"
- ):
+ async def get_room_events_stream_for_room(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_id
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
- events (in ascending order) and the token from the start of
- the chunk of events returned.
+ The list of events (in ascending order) and the token from the start
+ of the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
+ has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
if not has_changed:
return [], from_key
@@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+ rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- @defer.inlineCallbacks
- def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(self, user_id, from_key, to_key):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
+ rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
- @defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token):
+ async def get_recent_events_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
- events and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of events and a token pointing to the start of the returned
+ events. The events returned are in ascending order.
"""
- rows, token = yield self.get_recent_event_ids_for_room(
+ rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
- @defer.inlineCallbacks
- def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ async def get_recent_event_ids_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
- _EventDictReturn and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of _EventDictReturn and a token pointing to the start of the
+ returned events. The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
- def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
"""Gets details of the first event in a room at or before a stream ordering
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
Deferred[(int, int, str)]:
@@ -574,55 +582,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
- def get_stream_token_for_event(self, event_id):
+ async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "s%d" stream token.
+ A "s%d" stream token.
"""
- return self.db_pool.simple_select_one_onecol(
+ row = await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
- ).addCallback(lambda row: "s%d" % (row,))
+ )
+ return "s%d" % (row,)
- def get_topological_token_for_event(self, event_id):
+ async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "t%d-%d" topological token.
+ A "t%d-%d" topological token.
"""
- return self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
- ).addCallback(
- lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
+ return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
- def get_max_topological_token(self, room_id, stream_key):
+ async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
"""Get the max topological token in a room before the given stream
ordering.
Args:
- room_id (str)
- stream_key (int)
+ room_id
+ stream_key
Returns:
- Deferred[int]
+ The maximum topological token.
"""
sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self.db_pool.execute(
+ row = await self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
- ).addCallback(lambda r: r[0][0] if r else 0)
+ )
+ return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
@@ -634,16 +643,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
- def _set_before_and_after(events, rows, topo_order=True):
+ def _set_before_and_after(
+ events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+ ):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
- events (list[FrozenEvent])
- rows (list[_EventDictReturn])
- topo_order (bool): Whether the events were ordered topologically
- or by stream ordering. If true then all rows should have a non
- null topological_ordering.
+ events
+ rows
+ topo_order: Whether the events were ordered topologically or by stream
+ ordering. If true then all rows should have a non null
+ topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -656,25 +667,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
- @defer.inlineCallbacks
- def get_events_around(
- self, room_id, event_id, before_limit, after_limit, event_filter=None
- ):
+ async def get_events_around(
+ self,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter] = None,
+ ) -> dict:
"""Retrieve events and pagination tokens around a given event in a
room.
-
- Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
-
- Returns:
- dict
"""
- results = yield self.db_pool.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -684,11 +689,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self.get_events_as_list(
+ events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
- events_after = yield self.get_events_as_list(
+ events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -700,17 +705,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
- self, txn, room_id, event_id, before_limit, after_limit, event_filter
- ):
+ self,
+ txn,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter],
+ ) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
+ room_id
+ event_id
+ before_limit
+ after_limit
+ event_filter
Returns:
dict
@@ -758,22 +769,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- @defer.inlineCallbacks
- def get_all_new_events_stream(self, from_id, current_id, limit):
+ async def get_all_new_events_stream(
+ self, from_id: int, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
- from_id (int): the stream_ordering of the last event we processed
- current_id (int): the stream_ordering of the most recently processed event
- limit (int): the maximum number of events to return
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
Returns:
- Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
- `next_id` is the next value to pass as `from_id` (it will either be the
- stream_ordering of the last returned event, or, if fewer than `limit` events
- were found, `current_id`.
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
@@ -795,11 +807,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db_pool.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
@@ -817,21 +829,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos",
)
- async def update_federation_out_pos(self, typ, stream_id):
+ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
- return await self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
- def _reset_federation_positions_txn(self, txn):
+ def _reset_federation_positions_txn(self, txn) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -892,39 +904,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id},
)
- def has_room_changed_since(self, room_id, stream_id):
+ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
txn,
- room_id,
- from_token,
- to_token=None,
- direction="b",
- limit=-1,
- event_filter=None,
- ):
+ room_id: str,
+ from_token: RoomStreamToken,
+ to_token: Optional[RoomStreamToken] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token.
Args:
txn
- room_id (str)
- from_token (RoomStreamToken): The token used to stream from
- to_token (RoomStreamToken|None): A token which if given limits the
- results to only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
+ room_id
+ from_token: The token used to stream from
+ to_token: A token which if given limits the results to only those before
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided filters the events to
those that match the filter.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
- as a list of _EventDictReturn and a token that points to the end
- of the result set. If no events are returned then the end of the
- stream has been reached (i.e. there are no events between
- `from_token` and `to_token`), or `limit` is zero.
+ A list of _EventDictReturn and a token that points to the end of the
+ result set. If no events are returned then the end of the stream has
+ been reached (i.e. there are no events between `from_token` and
+ `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -1008,35 +1018,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
- @defer.inlineCallbacks
- def paginate_room_events(
- self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
- ):
+ async def paginate_room_events(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: Optional[str] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token.
Args:
- room_id (str)
- from_key (str): The token used to stream from
- to_key (str|None): A token which if given limits the results to
- only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
- those that match the filter.
+ room_id
+ from_key: The token used to stream from
+ to_key: A token which if given limits the results to only those before
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided filters the events to those that match the filter.
Returns:
- tuple[list[FrozenEvent], str]: Returns the results as a list of
- events and a token that points to the end of the result set. If no
- events are returned then the end of the stream has been reached
- (i.e. there are no events between `from_key` and `to_key`).
+ The results as a list of events and a token that points to the end
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between `from_key`
+ and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -1047,7 +1060,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -1057,8 +1070,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token()
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..da23fe7355 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
desc="is_user_erased",
).addCallback(operator.truth)
- @cachedList(
- cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
- )
- def are_users_erased(self, user_ids):
+ @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+ async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
- Deferred[dict[str, bool]]:
+ dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
- res = {u: u in erased_users for u in user_ids}
- return res
+ return {u: u in erased_users for u in user_ids}
class UserErasureStore(UserErasureWorkerStore):
diff --git a/synapse/types.py b/synapse/types.py
index 9e580f4295..bc36cdde30 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -51,7 +51,15 @@ JsonDict = Dict[str, Any]
class Requester(
namedtuple(
- "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+ "Requester",
+ [
+ "user",
+ "access_token_id",
+ "is_guest",
+ "shadow_banned",
+ "device_id",
+ "app_service",
+ ],
)
):
"""
@@ -62,6 +70,7 @@ class Requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
"""
@@ -77,6 +86,7 @@ class Requester(
"user_id": self.user.to_string(),
"access_token_id": self.access_token_id,
"is_guest": self.is_guest,
+ "shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
}
@@ -101,13 +111,19 @@ class Requester(
user=UserID.from_string(input["user_id"]),
access_token_id=input["access_token_id"],
is_guest=input["is_guest"],
+ shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
)
def create_requester(
- user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+ user_id,
+ access_token_id=None,
+ is_guest=False,
+ shadow_banned=False,
+ device_id=None,
+ app_service=None,
):
"""
Create a new ``Requester`` object
@@ -117,6 +133,7 @@ def create_requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
@@ -125,7 +142,9 @@ def create_requester(
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
- return Requester(user_id, access_token_id, is_guest, device_id, app_service)
+ return Requester(
+ user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+ )
def get_domain_from_id(string):
|