diff --git a/synapse/__init__.py b/synapse/__init__.py
index 5ef34bce40..aa964afb5e 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.46.0"
+__version__ = "1.47.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
index 9de913db88..8e66a38421 100644
--- a/synapse/_scripts/review_recent_signups.py
+++ b/synapse/_scripts/review_recent_signups.py
@@ -20,7 +20,12 @@ from typing import List
import attr
-from synapse.config._base import RootConfig, find_config_files, read_config_files
+from synapse.config._base import (
+ Config,
+ RootConfig,
+ find_config_files,
+ read_config_files,
+)
from synapse.config.database import DatabaseConfig
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.engines import create_engine
@@ -126,7 +131,7 @@ def main():
config_dict,
)
- since_ms = time.time() * 1000 - config.parse_duration(config_args.since)
+ since_ms = time.time() * 1000 - Config.parse_duration(config_args.since)
exclude_users_with_email = config_args.exclude_emails
include_context = not config_args.only_users
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 685d1c25cf..85302163da 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -596,3 +596,10 @@ class ShadowBanError(Exception):
This should be caught and a proper "fake" success response sent to the user.
"""
+
+
+class ModuleFailedException(Exception):
+ """
+ Raised when a module API callback fails, for example because it raised an
+ exception.
+ """
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index bc550ae646..4b0a9b2974 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -18,7 +18,8 @@ import json
from typing import (
TYPE_CHECKING,
Awaitable,
- Container,
+ Callable,
+ Dict,
Iterable,
List,
Optional,
@@ -217,19 +218,19 @@ class FilterCollection:
return self._filter_json
def timeline_limit(self) -> int:
- return self._room_timeline_filter.limit()
+ return self._room_timeline_filter.limit
def presence_limit(self) -> int:
- return self._presence_filter.limit()
+ return self._presence_filter.limit
def ephemeral_limit(self) -> int:
- return self._room_ephemeral_filter.limit()
+ return self._room_ephemeral_filter.limit
def lazy_load_members(self) -> bool:
- return self._room_state_filter.lazy_load_members()
+ return self._room_state_filter.lazy_load_members
def include_redundant_members(self) -> bool:
- return self._room_state_filter.include_redundant_members()
+ return self._room_state_filter.include_redundant_members
def filter_presence(
self, events: Iterable[UserPresenceState]
@@ -276,19 +277,25 @@ class Filter:
def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json
- self.types = self.filter_json.get("types", None)
- self.not_types = self.filter_json.get("not_types", [])
+ self.limit = filter_json.get("limit", 10)
+ self.lazy_load_members = filter_json.get("lazy_load_members", False)
+ self.include_redundant_members = filter_json.get(
+ "include_redundant_members", False
+ )
+
+ self.types = filter_json.get("types", None)
+ self.not_types = filter_json.get("not_types", [])
- self.rooms = self.filter_json.get("rooms", None)
- self.not_rooms = self.filter_json.get("not_rooms", [])
+ self.rooms = filter_json.get("rooms", None)
+ self.not_rooms = filter_json.get("not_rooms", [])
- self.senders = self.filter_json.get("senders", None)
- self.not_senders = self.filter_json.get("not_senders", [])
+ self.senders = filter_json.get("senders", None)
+ self.not_senders = filter_json.get("not_senders", [])
- self.contains_url = self.filter_json.get("contains_url", None)
+ self.contains_url = filter_json.get("contains_url", None)
- self.labels = self.filter_json.get("org.matrix.labels", None)
- self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+ self.labels = filter_json.get("org.matrix.labels", None)
+ self.not_labels = filter_json.get("org.matrix.not_labels", [])
def filters_all_types(self) -> bool:
return "*" in self.not_types
@@ -302,76 +309,95 @@ class Filter:
def check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
+ Args:
+ event: The event, account data, or presence to check against this
+ filter.
+
Returns:
- True if the event matches
+ True if the event matches the filter.
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
# namedtuple type.
if isinstance(event, UserPresenceState):
- sender: Optional[str] = event.user_id
- room_id = None
- ev_type = "m.presence"
- contains_url = False
- labels: List[str] = []
+ user_id = event.user_id
+ field_matchers = {
+ "senders": lambda v: user_id == v,
+ "types": lambda v: "m.presence" == v,
+ }
+ return self._check_fields(field_matchers)
else:
+ content = event.get("content")
+ # Content is assumed to be a dict below, so ensure it is. This should
+ # always be true for events, but account_data has been allowed to
+ # have non-dict content.
+ if not isinstance(content, dict):
+ content = {}
+
sender = event.get("sender", None)
if not sender:
# Presence events had their 'sender' in content.user_id, but are
# now handled above. We don't know if anything else uses this
# form. TODO: Check this and probably remove it.
- content = event.get("content")
- # account_data has been allowed to have non-dict content, so
- # check type first
- if isinstance(content, dict):
- sender = content.get("user_id")
+ sender = content.get("user_id")
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
- content = event.get("content") or {}
# check if there is a string url field in the content for filtering purposes
- contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
- return self.check_fields(room_id, sender, ev_type, labels, contains_url)
+ field_matchers = {
+ "rooms": lambda v: room_id == v,
+ "senders": lambda v: sender == v,
+ "types": lambda v: _matches_wildcard(ev_type, v),
+ "labels": lambda v: v in labels,
+ }
+
+ result = self._check_fields(field_matchers)
+ if not result:
+ return result
+
+ contains_url_filter = self.contains_url
+ if contains_url_filter is not None:
+ contains_url = isinstance(content.get("url"), str)
+ if contains_url_filter != contains_url:
+ return False
+
+ return True
- def check_fields(
- self,
- room_id: Optional[str],
- sender: Optional[str],
- event_type: Optional[str],
- labels: Container[str],
- contains_url: bool,
- ) -> bool:
+ def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
"""Checks whether the filter matches the given event fields.
+ Args:
+ field_matchers: A map of attribute name to callable to use for checking
+ particular fields.
+
+ The attribute name and an inverse (not_<attribute name>) must
+ exist on the Filter.
+
+ The callable should return true if the event's value matches the
+ filter's value.
+
Returns:
True if the event fields match
"""
- literal_keys = {
- "rooms": lambda v: room_id == v,
- "senders": lambda v: sender == v,
- "types": lambda v: _matches_wildcard(event_type, v),
- "labels": lambda v: v in labels,
- }
-
- for name, match_func in literal_keys.items():
+
+ for name, match_func in field_matchers.items():
+ # If the event matches one of the disallowed values, reject it.
not_name = "not_%s" % (name,)
disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)):
return False
+ # Other the event does not match at least one of the allowed values,
+ # reject it.
allowed_values = getattr(self, name)
if allowed_values is not None:
if not any(map(match_func, allowed_values)):
return False
- contains_url_filter = self.filter_json.get("contains_url")
- if contains_url_filter is not None:
- if contains_url_filter != contains_url:
- return False
-
+ # Otherwise, accept it.
return True
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
@@ -385,10 +411,10 @@ class Filter:
"""
room_ids = set(room_ids)
- disallowed_rooms = set(self.filter_json.get("not_rooms", []))
+ disallowed_rooms = set(self.not_rooms)
room_ids -= disallowed_rooms
- allowed_rooms = self.filter_json.get("rooms", None)
+ allowed_rooms = self.rooms
if allowed_rooms is not None:
room_ids &= set(allowed_rooms)
@@ -397,15 +423,6 @@ class Filter:
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
- def limit(self) -> int:
- return self.filter_json.get("limit", 10)
-
- def lazy_load_members(self) -> bool:
- return self.filter_json.get("lazy_load_members", False)
-
- def include_redundant_members(self) -> bool:
- return self.filter_json.get("include_redundant_members", False)
-
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 6e84b1524f..4486b3bc7d 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -38,9 +38,6 @@ class ConsentURIBuilder:
def __init__(self, hs_config: HomeServerConfig):
if hs_config.key.form_secret is None:
raise ConfigError("form_secret not set in config")
- if hs_config.server.public_baseurl is None:
- raise ConfigError("public_baseurl not set in config")
-
self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
self._public_baseurl = hs_config.server.public_baseurl
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index f4c3f867a8..f2c1028b5d 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -45,6 +45,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
+from synapse.metrics import register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
@@ -351,6 +352,10 @@ async def start(hs: "HomeServer"):
GAIResolver(reactor, getThreadPool=lambda: resolver_threadpool)
)
+ # Register the threadpools with our metrics.
+ register_threadpool("default", reactor.getThreadPool())
+ register_threadpool("gai_resolver", resolver_threadpool)
+
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 2fc848596d..ad20b1d6aa 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -145,6 +145,20 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for event in state.values():
print(json.dumps(event), file=f)
+ def write_knock(self, room_id, event, state):
+ self.write_events(room_id, [event])
+
+ # We write the knock state somewhere else as they aren't full events
+ # and are only a subset of the state at the event.
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ os.makedirs(room_directory, exist_ok=True)
+
+ knock_state = os.path.join(room_directory, "knock_state")
+
+ with open(knock_state, "a") as f:
+ for event in state.values():
+ print(json.dumps(event), file=f)
+
def finished(self):
return self.base_directory
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 51eadf122d..218826741e 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -100,6 +100,7 @@ from synapse.rest.client.register import (
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
@@ -318,6 +319,8 @@ class GenericWorkerServer(HomeServer):
resources.update({CLIENT_API_PREFIX: resource})
resources.update(build_synapse_client_resource_tree(self))
+ resources.update({"/.well-known": well_known_resource(self)})
+
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 93e2299266..336c279a44 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -66,7 +66,7 @@ from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
-from synapse.rest.well_known import WellKnownResource
+from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.util.httpresourcetree import create_resource_tree
@@ -189,7 +189,7 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
- "/.well-known/matrix/client": WellKnownResource(self),
+ "/.well-known": well_known_resource(self),
"/_synapse/admin": AdminRestResource(self),
**build_synapse_client_resource_tree(self),
}
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index b56c2a24df..c533452cab 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -75,10 +75,6 @@ class AccountValidityConfig(Config):
self.account_validity_period * 10.0 / 100.0
)
- if self.account_validity_renew_by_email_enabled:
- if not self.root.server.public_baseurl:
- raise ConfigError("Can't send renewal emails without 'public_baseurl'")
-
# Load account validity templates.
account_validity_template_dir = account_validity_config.get("template_dir")
if account_validity_template_dir is not None:
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 9b58ecf3d8..3f81814043 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -16,7 +16,7 @@ from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
-from ._base import Config, ConfigError
+from ._base import Config
from ._util import validate_config
@@ -35,14 +35,10 @@ class CasConfig(Config):
if self.cas_enabled:
self.cas_server_url = cas_config["server_url"]
- # The public baseurl is required because it is used by the redirect
- # template.
- public_baseurl = self.root.server.public_baseurl
- if not public_baseurl:
- raise ConfigError("cas_config requires a public_baseurl to be set")
-
# TODO Update this to a _synapse URL.
+ public_baseurl = self.root.server.public_baseurl
self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
+
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
required_attributes = cas_config.get("required_attributes") or {}
self.cas_required_attributes = _parsed_required_attributes_def(
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 8ff59aa2f8..afd65fecd3 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -186,11 +186,6 @@ class EmailConfig(Config):
if not self.email_notif_from:
missing.append("email.notif_from")
- # public_baseurl is required to build password reset and validation links that
- # will be emailed to users
- if config.get("public_baseurl") is None:
- missing.append("public_baseurl")
-
if missing:
raise ConfigError(
MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
@@ -296,9 +291,6 @@ class EmailConfig(Config):
if not self.email_notif_from:
missing.append("email.notif_from")
- if config.get("public_baseurl") is None:
- missing.append("public_baseurl")
-
if missing:
raise ConfigError(
"email.enable_notifs is True but required keys are missing: %s"
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 10f5796330..42f113cd24 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -59,8 +59,6 @@ class OIDCConfig(Config):
)
public_baseurl = self.root.server.public_baseurl
- if public_baseurl is None:
- raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback"
@property
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index a3d2a38c4c..5379e80715 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -45,17 +45,6 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
- if (
- self.account_threepid_delegate_msisdn
- and not self.root.server.public_baseurl
- ):
- raise ConfigError(
- "The configuration option `public_baseurl` is required if "
- "`account_threepid_delegate.msisdn` is set, such that "
- "clients know where to submit validation tokens to. Please "
- "configure `public_baseurl`."
- )
-
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
@@ -240,7 +229,7 @@ class RegistrationConfig(Config):
# in on this server.
#
# (By default, no suggestion is made, so it is left up to the client.
- # This setting is ignored unless public_baseurl is also set.)
+ # This setting is ignored unless public_baseurl is also explicitly set.)
#
#default_identity_server: https://matrix.org
@@ -265,8 +254,6 @@ class RegistrationConfig(Config):
# by the Matrix Identity Service API specification:
# https://matrix.org/docs/spec/identity_service/latest
#
- # If a delegate is specified, the config option public_baseurl must also be filled out.
- #
account_threepid_delegates:
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 9c51b6a25a..ba2b0905ff 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -199,14 +199,11 @@ class SAML2Config(Config):
"""
import saml2
- public_baseurl = self.root.server.public_baseurl
- if public_baseurl is None:
- raise ConfigError("saml2_config requires a public_baseurl to be set")
-
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
+ public_baseurl = self.root.server.public_baseurl
metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
response_url = public_baseurl + "_synapse/client/saml2/authn_response"
return {
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ed094bdc44..7bc0030a9e 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -16,6 +16,7 @@ import itertools
import logging
import os.path
import re
+import urllib.parse
from textwrap import indent
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
@@ -262,11 +263,46 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
+ self.serve_server_wellknown = config.get("serve_server_wellknown", False)
+
+ # Whether we should serve a "client well-known":
+ # (a) at .well-known/matrix/client on our client HTTP listener
+ # (b) in the response to /login
+ #
+ # ... which together help ensure that clients use our public_baseurl instead of
+ # whatever they were told by the user.
+ #
+ # For the sake of backwards compatibility with existing installations, this is
+ # True if public_baseurl is specified explicitly, and otherwise False. (The
+ # reasoning here is that we have no way of knowing that the default
+ # public_baseurl is actually correct for existing installations - many things
+ # will not work correctly, but that's (probably?) better than sending clients
+ # to a completely broken URL.
+ self.serve_client_wellknown = False
+
+ public_baseurl = config.get("public_baseurl")
+ if public_baseurl is None:
+ public_baseurl = f"https://{self.server_name}/"
+ logger.info("Using default public_baseurl %s", public_baseurl)
+ else:
+ self.serve_client_wellknown = True
+ if public_baseurl[-1] != "/":
+ public_baseurl += "/"
+ self.public_baseurl = public_baseurl
- self.public_baseurl = config.get("public_baseurl")
- if self.public_baseurl is not None:
- if self.public_baseurl[-1] != "/":
- self.public_baseurl += "/"
+ # check that public_baseurl is valid
+ try:
+ splits = urllib.parse.urlsplit(self.public_baseurl)
+ except Exception as e:
+ raise ConfigError(f"Unable to parse URL: {e}", ("public_baseurl",))
+ if splits.scheme not in ("https", "http"):
+ raise ConfigError(
+ f"Invalid scheme '{splits.scheme}': only https and http are supported"
+ )
+ if splits.query or splits.fragment:
+ raise ConfigError(
+ "public_baseurl cannot contain query parameters or a #-fragment"
+ )
# Whether to enable user presence.
presence_config = config.get("presence") or {}
@@ -772,8 +808,28 @@ class ServerConfig(Config):
# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
# 'listeners' below).
#
+ # Defaults to 'https://<server_name>/'.
+ #
#public_baseurl: https://example.com/
+ # Uncomment the following to tell other servers to send federation traffic on
+ # port 443.
+ #
+ # By default, other servers will try to reach our server on port 8448, which can
+ # be inconvenient in some environments.
+ #
+ # Provided 'https://<server_name>/' on port 443 is routed to Synapse, this
+ # option configures Synapse to serve a file at
+ # 'https://<server_name>/.well-known/matrix/server'. This will tell other
+ # servers to send traffic to port 443 instead.
+ #
+ # See https://matrix-org.github.io/synapse/latest/delegate.html for more
+ # information.
+ #
+ # Defaults to 'false'.
+ #
+ #serve_server_wellknown: true
+
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 11a9b76aa0..60aacb13ea 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -101,13 +101,10 @@ class SSOConfig(Config):
# gracefully to the client). This would make it pointless to ask the user for
# confirmation, since the URL the confirmation page would be showing wouldn't be
# the client's.
- # public_baseurl is an optional setting, so we only add the fallback's URL to the
- # list if it's provided (because we can't figure out what that URL is otherwise).
- if self.root.server.public_baseurl:
- login_fallback_url = (
- self.root.server.public_baseurl + "_matrix/static/client/login"
- )
- self.sso_client_whitelist.append(login_fallback_url)
+ login_fallback_url = (
+ self.root.server.public_baseurl + "_matrix/static/client/login"
+ )
+ self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs):
return """\
@@ -128,11 +125,10 @@ class SSOConfig(Config):
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
- # If public_baseurl is set, then the login fallback page (used by clients
- # that don't natively support the required login flows) is whitelisted in
- # addition to any URLs in this list.
+ # The login fallback page (used by clients that don't natively support the
+ # required login flows) is whitelisted in addition to any URLs in this list.
#
- # By default, this list is empty.
+ # By default, this list contains only the login fallback page.
#
#client_whitelist:
# - https://riot.im/develop
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 462630201d..4507992031 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -63,7 +63,8 @@ class WriterLocations:
Attributes:
events: The instances that write to the event and backfill streams.
- typing: The instance that writes to the typing stream.
+ typing: The instances that write to the typing stream. Currently
+ can only be a single instance.
to_device: The instances that write to the to_device stream. Currently
can only be a single instance.
account_data: The instances that write to the account data streams. Currently
@@ -75,9 +76,15 @@ class WriterLocations:
"""
events = attr.ib(
- default=["master"], type=List[str], converter=_instance_to_list_converter
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
+ )
+ typing = attr.ib(
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
)
- typing = attr.ib(default="master", type=str)
to_device = attr.ib(
default=["master"],
type=List[str],
@@ -217,6 +224,11 @@ class WorkerConfig(Config):
% (instance, stream)
)
+ if len(self.writers.typing) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `typing` messages."
+ )
+
if len(self.writers.to_device) != 1:
raise ConfigError(
"Must only specify one instance to handle `to_device` messages."
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8628e951c4..f641ab7ef5 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -22,6 +22,7 @@ import attr
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
+ get_verify_key,
is_signing_algorithm_supported,
)
from signedjson.sign import (
@@ -30,6 +31,7 @@ from signedjson.sign import (
signature_ids,
verify_signed_json,
)
+from signedjson.types import VerifyKey
from unpaddedbase64 import decode_base64
from twisted.internet import defer
@@ -177,6 +179,8 @@ class Keyring:
clock=hs.get_clock(),
process_batch_callback=self._inner_fetch_key_requests,
)
+ self.verify_key = get_verify_key(hs.signing_key)
+ self.hostname = hs.hostname
async def verify_json_for_server(
self,
@@ -196,6 +200,7 @@ class Keyring:
validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
"""
+
request = VerifyJsonRequest.from_json_object(
server_name,
json_object,
@@ -262,6 +267,11 @@ class Keyring:
Codes.UNAUTHORIZED,
)
+ # If we are the originating server don't fetch verify key for self over federation
+ if verify_request.server_name == self.hostname:
+ await self._process_json(self.verify_key, verify_request)
+ return
+
# Add the keys we need to verify to the queue for retrieval. We queue
# up requests for the same server so we don't end up with many in flight
# requests for the same keys.
@@ -285,35 +295,8 @@ class Keyring:
if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
continue
- verify_key = key_result.verify_key
- json_object = verify_request.get_json_object()
- try:
- verify_signed_json(
- json_object,
- verify_request.server_name,
- verify_key,
- )
- verified = True
- except SignatureVerifyException as e:
- logger.debug(
- "Error verifying signature for %s:%s:%s with key %s: %s",
- verify_request.server_name,
- verify_key.alg,
- verify_key.version,
- encode_verify_key_base64(verify_key),
- str(e),
- )
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s: %s"
- % (
- verify_request.server_name,
- verify_key.alg,
- verify_key.version,
- str(e),
- ),
- Codes.UNAUTHORIZED,
- )
+ await self._process_json(key_result.verify_key, verify_request)
+ verified = True
if not verified:
raise SynapseError(
@@ -322,6 +305,39 @@ class Keyring:
Codes.UNAUTHORIZED,
)
+ async def _process_json(
+ self, verify_key: VerifyKey, verify_request: VerifyJsonRequest
+ ) -> None:
+ """Processes the `VerifyJsonRequest`. Raises if the signature can't be
+ verified.
+ """
+ try:
+ verify_signed_json(
+ verify_request.get_json_object(),
+ verify_request.server_name,
+ verify_key,
+ )
+ except SignatureVerifyException as e:
+ logger.debug(
+ "Error verifying signature for %s:%s:%s with key %s: %s",
+ verify_request.server_name,
+ verify_key.alg,
+ verify_key.version,
+ encode_verify_key_base64(verify_key),
+ str(e),
+ )
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s: %s"
+ % (
+ verify_request.server_name,
+ verify_key.alg,
+ verify_key.version,
+ str(e),
+ ),
+ Codes.UNAUTHORIZED,
+ )
+
async def _inner_fetch_key_requests(
self, requests: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 157669ea88..38f3cf4d33 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -16,8 +16,23 @@
import abc
import os
-from typing import Dict, Optional, Tuple, Type
-
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+
+from typing_extensions import Literal
from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
+if TYPE_CHECKING:
+ from synapse.events.builder import EventBuilder
+
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
# dict to frozen_dicts is expensive.
@@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
-class DictProperty:
+T = TypeVar("T")
+
+
+# DictProperty (and DefaultDictProperty) require the classes they're used with to
+# have a _dict property to pull properties from.
+#
+# TODO _DictPropertyInstance should not include EventBuilder but due to
+# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and
+# DefaultDictProperty get applied to EventBuilder when it is in a Union with
+# EventBase. This is the least invasive hack to get mypy to comply.
+#
+# Note that DictProperty/DefaultDictProperty cannot actually be used with
+# EventBuilder as it lacks a _dict property.
+_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"]
+
+
+class DictProperty(Generic[T]):
"""An object property which delegates to the `_dict` within its parent object."""
__slots__ = ["key"]
@@ -45,12 +79,33 @@ class DictProperty:
def __init__(self, key: str):
self.key = key
- def __get__(self, instance, owner=None):
+ @overload
+ def __get__(
+ self,
+ instance: Literal[None],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> "DictProperty":
+ ...
+
+ @overload
+ def __get__(
+ self,
+ instance: _DictPropertyInstance,
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> T:
+ ...
+
+ def __get__(
+ self,
+ instance: Optional[_DictPropertyInstance],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> Union[T, "DictProperty"]:
# if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value
if instance is None:
return self
try:
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict[self.key]
except KeyError as e1:
# We want this to look like a regular attribute error (mostly so that
@@ -65,10 +120,12 @@ class DictProperty:
"'%s' has no '%s' property" % (type(instance), self.key)
) from e1.__context__
- def __set__(self, instance, v):
+ def __set__(self, instance: _DictPropertyInstance, v: T) -> None:
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
instance._dict[self.key] = v
- def __delete__(self, instance):
+ def __delete__(self, instance: _DictPropertyInstance) -> None:
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
try:
del instance._dict[self.key]
except KeyError as e1:
@@ -77,7 +134,7 @@ class DictProperty:
) from e1.__context__
-class DefaultDictProperty(DictProperty):
+class DefaultDictProperty(DictProperty, Generic[T]):
"""An extension of DictProperty which provides a default if the property is
not present in the parent's _dict.
@@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty):
__slots__ = ["default"]
- def __init__(self, key, default):
+ def __init__(self, key: str, default: T):
super().__init__(key)
self.default = default
- def __get__(self, instance, owner=None):
+ @overload
+ def __get__(
+ self,
+ instance: Literal[None],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> "DefaultDictProperty":
+ ...
+
+ @overload
+ def __get__(
+ self,
+ instance: _DictPropertyInstance,
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> T:
+ ...
+
+ def __get__(
+ self,
+ instance: Optional[_DictPropertyInstance],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> Union[T, "DefaultDictProperty"]:
if instance is None:
return self
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict.get(self.key, self.default)
@@ -111,22 +189,22 @@ class _EventInternalMetadata:
# in the DAG)
self.outlier = False
- out_of_band_membership: bool = DictProperty("out_of_band_membership")
- send_on_behalf_of: str = DictProperty("send_on_behalf_of")
- recheck_redaction: bool = DictProperty("recheck_redaction")
- soft_failed: bool = DictProperty("soft_failed")
- proactively_send: bool = DictProperty("proactively_send")
- redacted: bool = DictProperty("redacted")
- txn_id: str = DictProperty("txn_id")
- token_id: int = DictProperty("token_id")
- historical: bool = DictProperty("historical")
+ out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership")
+ send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of")
+ recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction")
+ soft_failed: DictProperty[bool] = DictProperty("soft_failed")
+ proactively_send: DictProperty[bool] = DictProperty("proactively_send")
+ redacted: DictProperty[bool] = DictProperty("redacted")
+ txn_id: DictProperty[str] = DictProperty("txn_id")
+ token_id: DictProperty[int] = DictProperty("token_id")
+ historical: DictProperty[bool] = DictProperty("historical")
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
- before: RoomStreamToken = DictProperty("before")
- after: RoomStreamToken = DictProperty("after")
- order: Tuple[int, int] = DictProperty("order")
+ before: DictProperty[RoomStreamToken] = DictProperty("before")
+ after: DictProperty[RoomStreamToken] = DictProperty("after")
+ order: DictProperty[Tuple[int, int]] = DictProperty("order")
def get_dict(self) -> JsonDict:
return dict(self._dict)
@@ -162,9 +240,6 @@ class _EventInternalMetadata:
If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.
-
- Returns:
- bool
"""
return self._dict.get("recheck_redaction", False)
@@ -176,32 +251,23 @@ class _EventInternalMetadata:
sent to clients.
2. They should not be added to the forward extremities (and
therefore not to current state).
-
- Returns:
- bool
"""
return self._dict.get("soft_failed", False)
- def should_proactively_send(self):
+ def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and
servers.
This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.
-
- Returns:
- bool
"""
return self._dict.get("proactively_send", True)
- def is_redacted(self):
+ def is_redacted(self) -> bool:
"""Whether the event has been redacted.
This is used for efficiently checking whether an event has been
marked as redacted without needing to make another database call.
-
- Returns:
- bool
"""
return self._dict.get("redacted", False)
@@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta):
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
- auth_events = DictProperty("auth_events")
- depth = DictProperty("depth")
- content = DictProperty("content")
- hashes = DictProperty("hashes")
- origin = DictProperty("origin")
- origin_server_ts = DictProperty("origin_server_ts")
- prev_events = DictProperty("prev_events")
- redacts = DefaultDictProperty("redacts", None)
- room_id = DictProperty("room_id")
- sender = DictProperty("sender")
- state_key = DictProperty("state_key")
- type = DictProperty("type")
- user_id = DictProperty("sender")
+ depth: DictProperty[int] = DictProperty("depth")
+ content: DictProperty[JsonDict] = DictProperty("content")
+ hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
+ origin: DictProperty[str] = DictProperty("origin")
+ origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
+ redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
+ room_id: DictProperty[str] = DictProperty("room_id")
+ sender: DictProperty[str] = DictProperty("sender")
+ # TODO state_key should be Optional[str], this is generally asserted in Synapse
+ # by calling is_state() first (which ensures this), but it is hard (not possible?)
+ # to properly annotate that calling is_state() asserts that state_key exists
+ # and is non-None.
+ state_key: DictProperty[str] = DictProperty("state_key")
+ type: DictProperty[str] = DictProperty("type")
+ user_id: DictProperty[str] = DictProperty("sender")
@property
def event_id(self) -> str:
raise NotImplementedError()
@property
- def membership(self):
+ def membership(self) -> str:
return self.content["membership"]
- def is_state(self):
+ def is_state(self) -> bool:
return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self) -> JsonDict:
@@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta):
return d
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
return self._dict.get(key, default)
- def get_internal_metadata_dict(self):
+ def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict()
- def get_pdu_json(self, time_now=None) -> JsonDict:
+ def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict:
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta):
return template_json
- def __set__(self, instance, value):
- raise AttributeError("Unrecognized attribute %s" % (instance,))
-
- def __getitem__(self, field):
+ def __getitem__(self, field: str) -> Optional[Any]:
return self._dict[field]
- def __contains__(self, field):
+ def __contains__(self, field: str) -> bool:
return field in self._dict
- def items(self):
+ def items(self) -> List[Tuple[str, Optional[Any]]]:
return list(self._dict.items())
- def keys(self):
+ def keys(self) -> Iterable[str]:
return self._dict.keys()
- def prev_event_ids(self):
+ def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's prev_events
+ The list of event IDs of this event's prev_events
"""
- return [e for e, _ in self.prev_events]
+ return [e for e, _ in self._dict["prev_events"]]
- def auth_event_ids(self):
+ def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's auth_events
+ The list of event IDs of this event's auth_events
"""
- return [e for e, _ in self.auth_events]
+ return [e for e, _ in self._dict["auth_events"]]
- def freeze(self):
+ def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
- def __str__(self):
+ def __str__(self) -> str:
return self.__repr__()
- def __repr__(self):
+ def __repr__(self) -> str:
rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
return (
@@ -443,7 +508,7 @@ class FrozenEventV2(EventBase):
else:
frozen_dict = event_dict
- self._event_id = None
+ self._event_id: Optional[str] = None
super().__init__(
frozen_dict,
@@ -455,7 +520,7 @@ class FrozenEventV2(EventBase):
)
@property
- def event_id(self):
+ def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -465,23 +530,23 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
- def prev_event_ids(self):
+ def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's prev_events
+ The list of event IDs of this event's prev_events
"""
- return self.prev_events
+ return self._dict["prev_events"]
- def auth_event_ids(self):
+ def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's auth_events
+ The list of event IDs of this event's auth_events
"""
- return self.auth_events
+ return self._dict["auth_events"]
class FrozenEventV3(FrozenEventV2):
@@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2):
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
- def event_id(self):
+ def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id
-def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
+def _event_type_from_format_version(
+ format_version: int,
+) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
"""Returns the python type to use to construct an Event object for the
given event format version.
Args:
- format_version (int): The event format version
+ format_version: The event format version
Returns:
type: A type that can be initialized as per the initializer of
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 2a6dabdab6..1bb8ca7145 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap
@@ -36,6 +36,7 @@ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
[str, StateMap[EventBase], str], Awaitable[bool]
]
+ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -152,6 +153,7 @@ class ThirdPartyEventRules:
self._check_visibility_can_be_modified_callbacks: List[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = []
+ self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
def register_third_party_rules_callbacks(
self,
@@ -163,6 +165,7 @@ class ThirdPartyEventRules:
check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None,
+ on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
@@ -181,6 +184,9 @@ class ThirdPartyEventRules:
check_visibility_can_be_modified,
)
+ if on_new_event is not None:
+ self._on_new_event_callbacks.append(on_new_event)
+
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> Tuple[bool, Optional[dict]]:
@@ -227,9 +233,10 @@ class ThirdPartyEventRules:
# This module callback needs a rework so that hacks such as
# this one are not necessary.
raise e
- except Exception as e:
- logger.warning("Failed to run module API callback %s: %s", callback, e)
- continue
+ except Exception:
+ raise ModuleFailedException(
+ "Failed to run `check_event_allowed` module API callback"
+ )
# Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event.
@@ -321,6 +328,31 @@ class ThirdPartyEventRules:
return True
+ async def on_new_event(self, event_id: str) -> None:
+ """Let modules act on events after they've been sent (e.g. auto-accepting
+ invites, etc.)
+
+ Args:
+ event_id: The ID of the event.
+
+ Raises:
+ ModuleFailureError if a callback raised any exception.
+ """
+ # Bail out early without hitting the store if we don't have any callbacks
+ if len(self._on_new_event_callbacks) == 0:
+ return
+
+ event = await self.store.get_event(event_id)
+ state_events = await self._get_state_map_for_room(event.room_id)
+
+ for callback in self._on_new_event_callbacks:
+ try:
+ await callback(event, state_events)
+ except Exception as e:
+ logger.exception(
+ "Failed to run module API callback %s: %s", callback, e
+ )
+
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
"""Given a room ID, return the state events of that room.
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 4d459c17f1..cf86934968 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -55,7 +55,7 @@ class EventValidator:
]
for k in required:
- if not hasattr(event, k):
+ if k not in event:
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 2ab4dec88f..670186f548 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -227,7 +227,7 @@ class FederationClient(FederationBase):
)
async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
+ self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the
given destination server.
@@ -237,6 +237,8 @@ class FederationClient(FederationBase):
room_id: The room_id to backfill.
limit: The maximum number of events to return.
extremities: our current backwards extremities, to backfill from
+ Must be a Collection that is falsy when empty.
+ (Iterable is not enough here!)
"""
logger.debug("backfill extrem=%s", extremities)
@@ -250,11 +252,22 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%r", transaction_data)
+ if not isinstance(transaction_data, dict):
+ # TODO we probably want an exception type specific to federation
+ # client validation.
+ raise TypeError("Backfill transaction_data is not a dict.")
+
+ transaction_data_pdus = transaction_data.get("pdus")
+ if not isinstance(transaction_data_pdus, list):
+ # TODO we probably want an exception type specific to federation
+ # client validation.
+ raise TypeError("transaction_data.pdus is not a list.")
+
room_version = await self.store.get_room_version(room_id)
pdus = [
event_from_pdu_json(p, room_version, outlier=False)
- for p in transaction_data["pdus"]
+ for p in transaction_data_pdus
]
# Check signatures and hash of pdus, removing any from the list that fail checks
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 0d66034f44..9a8758e9a6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -213,6 +213,11 @@ class FederationServer(FederationBase):
self._started_handling_of_staged_events = True
self._handle_old_staged_events()
+ # Start a periodic check for old staged events. This is to handle
+ # the case where locks time out, e.g. if another process gets killed
+ # without dropping its locks.
+ self._clock.looping_call(self._handle_old_staged_events, 60 * 1000)
+
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -295,14 +300,16 @@ class FederationServer(FederationBase):
Returns:
HTTP response code and body
"""
- response = await self.transaction_actions.have_responded(origin, transaction)
+ existing_response = await self.transaction_actions.have_responded(
+ origin, transaction
+ )
- if response:
+ if existing_response:
logger.debug(
"[%s] We've already responded to this request",
transaction.transaction_id,
)
- return response
+ return existing_response
logger.debug("[%s] Transaction is new", transaction.transaction_id)
@@ -632,7 +639,7 @@ class FederationServer(FederationBase):
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
- ) -> Dict[str, Union[EventBase, str]]:
+ ) -> JsonDict:
"""We've received a /make_knock/ request, so we create a partial knock
event for the room and hand that back, along with the room version, to the knocking
homeserver. We do *not* persist or process this event until the other server has
@@ -1230,10 +1237,6 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
- def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
- """Register that the EDU handler is on a different instance than master."""
- self._edu_type_to_instance[edu_type] = [instance_name]
-
def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
) -> None:
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index dc555cca0b..ab935e5a7e 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -149,7 +149,6 @@ class TransactionManager:
)
except HttpResponseException as e:
code = e.code
- response = e.response
set_tag(tags.ERROR, True)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8b247fe206..10b5aa5af8 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,19 @@
import logging
import urllib
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
import ijson
@@ -100,7 +112,7 @@ class TransportLayerClient:
@log_function
async def backfill(
- self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int
+ self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]:
"""Requests `limit` previous PDUs in a given context before list of
PDUs.
@@ -108,7 +120,9 @@ class TransportLayerClient:
Args:
destination
room_id
- event_tuples
+ event_tuples:
+ Must be a Collection that is falsy when empty.
+ (Iterable is not enough here!)
limit
Returns:
@@ -786,7 +800,7 @@ class TransportLayerClient:
@log_function
def join_group(
self, destination: str, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
+ ) -> Awaitable[JsonDict]:
"""Attempts to join a group"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
@@ -1296,14 +1310,17 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
self._coro_state = ijson.items_coro(
_event_list_parser(room_version, self._response.state),
prefix + "state.item",
+ use_float=True,
)
self._coro_auth = ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
+ use_float=True,
)
self._coro_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
+ use_float=True,
)
def write(self, data: bytes) -> int:
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a53cd62d3c..be3203ac80 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -90,6 +90,7 @@ class AdminHandler:
Membership.LEAVE,
Membership.BAN,
Membership.INVITE,
+ Membership.KNOCK,
),
)
@@ -122,6 +123,13 @@ class AdminHandler:
invited_state = invite.unsigned["invite_room_state"]
writer.write_invite(room_id, invite, invited_state)
+ if room.membership == Membership.KNOCK:
+ event_id = room.event_id
+ knock = await self.store.get_event(event_id, allow_none=True)
+ if knock:
+ knock_state = knock.unsigned["knock_room_state"]
+ writer.write_knock(room_id, knock, knock_state)
+
continue
# We only want to bother fetching events up to the last time they
@@ -239,6 +247,20 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def write_knock(
+ self, room_id: str, event: EventBase, state: StateMap[dict]
+ ) -> None:
+ """Write a knock for the room, with associated knock state.
+
+ Args:
+ room_id: The room ID the knock is for.
+ event: The knock event.
+ state: A subset of the state at the knock, with a subset of the
+ event keys (type, state_key content and sender).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 36c206dae6..ddc9105ee9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -34,6 +34,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
+from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -58,6 +59,10 @@ class ApplicationServicesHandler:
self.current_max = 0
self.is_processing = False
+ self._ephemeral_events_linearizer = Linearizer(
+ name="appservice_ephemeral_events"
+ )
+
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
@@ -182,7 +187,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: str,
- new_token: Optional[int],
+ new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
) -> None:
"""
@@ -203,7 +208,7 @@ class ApplicationServicesHandler:
Appservices will only receive ephemeral events that fall within their
registered user and room namespaces.
- new_token: The latest stream token.
+ new_token: The stream token of the event.
users: The users that should be informed of the new event, if any.
"""
if not self.notify_appservices:
@@ -212,6 +217,19 @@ class ApplicationServicesHandler:
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
return
+ # Assert that new_token is an integer (and not a RoomStreamToken).
+ # All of the supported streams that this function handles use an
+ # integer to track progress (rather than a RoomStreamToken - a
+ # vector clock implementation) as they don't support multiple
+ # stream writers.
+ #
+ # As a result, we simply assert that new_token is an integer.
+ # If we do end up needing to pass a RoomStreamToken down here
+ # in the future, using RoomStreamToken.stream (the minimum stream
+ # position) to convert to an ascending integer value should work.
+ # Additional context: https://github.com/matrix-org/synapse/pull/11137
+ assert isinstance(new_token, int)
+
services = [
service
for service in self.store.get_app_services()
@@ -231,14 +249,13 @@ class ApplicationServicesHandler:
self,
services: List[ApplicationService],
stream_key: str,
- new_token: Optional[int],
+ new_token: int,
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
- # Only handle typing if we have the latest token
- if stream_key == "typing_key" and new_token is not None:
+ if stream_key == "typing_key":
# Note that we don't persist the token (via set_type_stream_id_for_appservice)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
@@ -248,26 +265,37 @@ class ApplicationServicesHandler:
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
+ continue
- elif stream_key == "receipt_key":
- events = await self._handle_receipts(service)
- if events:
- self.scheduler.submit_ephemeral_events_for_as(service, events)
-
- # Persist the latest handled stream token for this appservice
- await self.store.set_type_stream_id_for_appservice(
- service, "read_receipt", new_token
+ # Since we read/update the stream position for this AS/stream
+ with (
+ await self._ephemeral_events_linearizer.queue(
+ (service.id, stream_key)
)
+ ):
+ if stream_key == "receipt_key":
+ events = await self._handle_receipts(service, new_token)
+ if events:
+ self.scheduler.submit_ephemeral_events_for_as(
+ service, events
+ )
+
+ # Persist the latest handled stream token for this appservice
+ await self.store.set_type_stream_id_for_appservice(
+ service, "read_receipt", new_token
+ )
- elif stream_key == "presence_key":
- events = await self._handle_presence(service, users)
- if events:
- self.scheduler.submit_ephemeral_events_for_as(service, events)
+ elif stream_key == "presence_key":
+ events = await self._handle_presence(service, users, new_token)
+ if events:
+ self.scheduler.submit_ephemeral_events_for_as(
+ service, events
+ )
- # Persist the latest handled stream token for this appservice
- await self.store.set_type_stream_id_for_appservice(
- service, "presence", new_token
- )
+ # Persist the latest handled stream token for this appservice
+ await self.store.set_type_stream_id_for_appservice(
+ service, "presence", new_token
+ )
async def _handle_typing(
self, service: ApplicationService, new_token: int
@@ -304,7 +332,9 @@ class ApplicationServicesHandler:
)
return typing
- async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
+ async def _handle_receipts(
+ self, service: ApplicationService, new_token: Optional[int]
+ ) -> List[JsonDict]:
"""
Return the latest read receipts that the given application service should receive.
@@ -323,6 +353,12 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
+ if new_token is not None and new_token <= from_key:
+ logger.debug(
+ "Rejecting token lower than or equal to stored: %s" % (new_token,)
+ )
+ return []
+
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
@@ -330,7 +366,10 @@ class ApplicationServicesHandler:
return receipts
async def _handle_presence(
- self, service: ApplicationService, users: Collection[Union[str, UserID]]
+ self,
+ service: ApplicationService,
+ users: Collection[Union[str, UserID]],
+ new_token: Optional[int],
) -> List[JsonDict]:
"""
Return the latest presence updates that the given application service should receive.
@@ -353,6 +392,12 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
+ if new_token is not None and new_token <= from_key:
+ logger.debug(
+ "Rejecting token lower than or equal to stored: %s" % (new_token,)
+ )
+ return []
+
for user in users:
if isinstance(user, str):
user = UserID.from_string(user)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d508d7d32a..60e59d11a0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1989,7 +1989,9 @@ class PasswordAuthProvider:
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
- auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+ auth_checkers: Optional[
+ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
+ ] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 8567cb0e00..8ca5f60b1c 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -245,7 +245,7 @@ class DirectoryHandler:
servers = result.servers
else:
try:
- fed_result = await self.federation.make_query(
+ fed_result: Optional[JsonDict] = await self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={"room_alias": room_alias.to_string()},
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d0fb2fc7dc..60c11e3d21 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -201,95 +201,19 @@ class E2eKeysHandler:
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
- @trace
- async def do_remote_query(destination: str) -> None:
- """This is called when we are querying the device list of a user on
- a remote homeserver and their device list is not in the device list
- cache. If we share a room with this user and we're not querying for
- specific user we will update the cache with their device list.
- """
-
- destination_query = remote_queries_not_in_cache[destination]
-
- # We first consider whether we wish to update the device list cache with
- # the users device list. We want to track a user's devices when the
- # authenticated user shares a room with the queried user and the query
- # has not specified a particular device.
- # If we update the cache for the queried user we remove them from further
- # queries. We use the more efficient batched query_client_keys for all
- # remaining users
- user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
-
- if device_list:
- continue
-
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
-
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- if self._is_master:
- user_devices = await self.device_handler.device_list_updater.user_device_resync(
- user_id
- )
- else:
- user_devices = await self._user_device_resync_client(
- user_id=user_id
- )
-
- user_devices = user_devices["devices"]
- user_results = results.setdefault(user_id, {})
- for device in user_devices:
- user_results[device["device_id"]] = device["keys"]
- user_ids_updated.append(user_id)
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
-
- if len(destination_query) == len(user_ids_updated):
- # We've updated all the users in the query and we do not need to
- # make any further remote calls.
- return
-
- # Remove all the users from the query which we have updated
- for user_id in user_ids_updated:
- destination_query.pop(user_id)
-
- try:
- remote_result = await self.federation.query_client_keys(
- destination, {"device_keys": destination_query}, timeout=timeout
- )
-
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in destination_query:
- results[user_id] = keys
-
- if "master_keys" in remote_result:
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
-
- if "self_signing_keys" in remote_result:
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
-
- except Exception as e:
- failure = _exception_to_failure(e)
- failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", failure)
-
await make_deferred_yieldable(
defer.gatherResults(
[
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
+ run_in_background(
+ self._query_devices_for_destination,
+ results,
+ cross_signing_keys,
+ failures,
+ destination,
+ queries,
+ timeout,
+ )
+ for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -301,6 +225,121 @@ class E2eKeysHandler:
return ret
+ @trace
+ async def _query_devices_for_destination(
+ self,
+ results: JsonDict,
+ cross_signing_keys: JsonDict,
+ failures: Dict[str, JsonDict],
+ destination: str,
+ destination_query: Dict[str, Iterable[str]],
+ timeout: int,
+ ) -> None:
+ """This is called when we are querying the device list of a user on
+ a remote homeserver and their device list is not in the device list
+ cache. If we share a room with this user and we're not querying for
+ specific user we will update the cache with their device list.
+
+ Args:
+ results: A map from user ID to their device keys, which gets
+ updated with the newly fetched keys.
+ cross_signing_keys: Map from user ID to their cross signing keys,
+ which gets updated with the newly fetched keys.
+ failures: Map of destinations to failures that have occurred while
+ attempting to fetch keys.
+ destination: The remote server to query
+ destination_query: The query dict of devices to query the remote
+ server for.
+ timeout: The timeout for remote HTTP requests.
+ """
+
+ # We first consider whether we wish to update the device list cache with
+ # the users device list. We want to track a user's devices when the
+ # authenticated user shares a room with the queried user and the query
+ # has not specified a particular device.
+ # If we update the cache for the queried user we remove them from further
+ # queries. We use the more efficient batched query_client_keys for all
+ # remaining users
+ user_ids_updated = []
+ for (user_id, device_list) in destination_query.items():
+ if user_id in user_ids_updated:
+ continue
+
+ if device_list:
+ continue
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ continue
+
+ # We've decided we're sharing a room with this user and should
+ # probably be tracking their device lists. However, we haven't
+ # done an initial sync on the device list so we do it now.
+ try:
+ if self._is_master:
+ resync_results = await self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ resync_results = await self._user_device_resync_client(
+ user_id=user_id
+ )
+
+ # Add the device keys to the results.
+ user_devices = resync_results["devices"]
+ user_results = results.setdefault(user_id, {})
+ for device in user_devices:
+ user_results[device["device_id"]] = device["keys"]
+ user_ids_updated.append(user_id)
+
+ # Add any cross signing keys to the results.
+ master_key = resync_results.get("master_key")
+ self_signing_key = resync_results.get("self_signing_key")
+
+ if master_key:
+ cross_signing_keys["master_keys"][user_id] = master_key
+
+ if self_signing_key:
+ cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
+
+ if len(destination_query) == len(user_ids_updated):
+ # We've updated all the users in the query and we do not need to
+ # make any further remote calls.
+ return
+
+ # Remove all the users from the query which we have updated
+ for user_id in user_ids_updated:
+ destination_query.pop(user_id)
+
+ try:
+ remote_result = await self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}, timeout=timeout
+ )
+
+ for user_id, keys in remote_result["device_keys"].items():
+ if user_id in destination_query:
+ results[user_id] = keys
+
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
+
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
+
+ except Exception as e:
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
+
+ return
+
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9584d5bd46..1a1cd93b1a 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -477,7 +477,7 @@ class FederationEventHandler:
@log_function
async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
+ self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
@@ -1643,7 +1643,7 @@ class FederationEventHandler:
event: the event whose auth_events we want
Returns:
- all of the events in `event.auth_events`, after deduplication
+ all of the events listed in `event.auth_events_ids`, after deduplication
Raises:
AuthError if we were unable to fetch the auth_events for any reason.
@@ -1916,7 +1916,7 @@ class FederationEventHandler:
event_pos = PersistedEventPosition(
self._instance_name, event.internal_metadata.stream_ordering
)
- self._notifier.on_new_room_event(
+ await self._notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 7ef8698a5e..3dbe611f95 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -537,10 +537,6 @@ class IdentityHandler:
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- # It is already checked that public_baseurl is configured since this code
- # should only be used if account_threepid_delegate_msisdn is true.
- assert self.hs.config.server.public_baseurl
-
# we need to tell the client to send the token back to us, since it doesn't
# otherwise know where to send it, so add submit_url response parameter
# (see also MSC2078)
@@ -879,6 +875,8 @@ class IdentityHandler:
}
if room_type is not None:
+ invite_config["room_type"] = room_type
+ # TODO The unstable field is deprecated and should be removed in the future.
invite_config["org.matrix.msc3288.room_type"] = room_type
# If a custom web client location is available, include it in the request.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2e024b551f..d4c2a6ab7a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1318,6 +1318,8 @@ class EventCreationHandler:
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
+ assert event.redacts is not None
+
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -1413,6 +1415,8 @@ class EventCreationHandler:
)
if event.type == EventTypes.Redaction:
+ assert event.redacts is not None
+
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -1500,11 +1504,13 @@ class EventCreationHandler:
next_batch_id = event.content.get(
EventContentFields.MSC2716_NEXT_BATCH_ID
)
- conflicting_insertion_event_id = (
- await self.store.get_insertion_event_by_batch_id(
- event.room_id, next_batch_id
+ conflicting_insertion_event_id = None
+ if next_batch_id:
+ conflicting_insertion_event_id = (
+ await self.store.get_insertion_event_id_by_batch_id(
+ event.room_id, next_batch_id
+ )
)
- )
if conflicting_insertion_event_id is not None:
# The current insertion event that we're processing is invalid
# because an insertion event already exists in the room with the
@@ -1537,13 +1543,16 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- def _notify() -> None:
+ async def _notify() -> None:
try:
- self.notifier.on_new_room_event(
+ await self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
- logger.exception("Error notifying about new room event")
+ logger.exception(
+ "Error notifying about new room event %s",
+ event.event_id,
+ )
run_in_background(_notify)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60ff896386..abfe7be0e3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -438,7 +438,7 @@ class PaginationHandler:
}
state = None
- if event_filter and event_filter.lazy_load_members() and len(events) > 0:
+ if event_filter and event_filter.lazy_load_members and len(events) > 0:
# TODO: remove redundant members
# FIXME: we also care about invite targets etc.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index fdab50da37..3df872c578 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -52,6 +52,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
+from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
@@ -1551,6 +1552,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
include_offline: bool = True,
+ service: Optional[ApplicationService] = None,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are:
# 1. Get the rooms the user is in.
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index e6c3cf585b..6b5a6ded8b 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -456,7 +456,11 @@ class ProfileHandler:
continue
new_name = profile.get("displayname")
+ if not isinstance(new_name, str):
+ new_name = None
new_avatar = profile.get("avatar_url")
+ if not isinstance(new_avatar, str):
+ new_avatar = None
# We always hit update to update the last_check timestamp
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index cf01d58ea1..969eb3b9b0 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -525,7 +525,7 @@ class RoomCreationHandler:
):
await self.room_member_handler.update_membership(
requester,
- UserID.from_string(old_event["state_key"]),
+ UserID.from_string(old_event.state_key),
new_room_id,
"ban",
ratelimit=False,
@@ -1173,7 +1173,7 @@ class RoomContextHandler:
else:
last_event_id = event_id
- if event_filter and event_filter.lazy_load_members():
+ if event_filter and event_filter.lazy_load_members:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 2f5a3e4d19..0723286383 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -355,7 +355,7 @@ class RoomBatchHandler:
for (event, context) in reversed(events_to_persist):
await self.event_creation_handler.handle_new_client_event(
await self.create_requester_for_user_id_from_app_service(
- event["sender"], app_service_requester.app_service
+ event.sender, app_service_requester.app_service
),
event=event,
context=context,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 74e6c7eca6..08244b690d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1669,7 +1669,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
# the prev_events consist solely of the previous membership event.
prev_event_ids = [previous_membership_event.event_id]
- auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
+ auth_event_ids = (
+ list(previous_membership_event.auth_event_ids()) + prev_event_ids
+ )
event, context = await self.event_creation_handler.create_event(
requester,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index a3ffa26be8..6e4dff8056 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -249,7 +249,7 @@ class SearchHandler:
)
events.sort(key=lambda e: -rank_map[e.event_id])
- allowed_events = events[: search_filter.limit()]
+ allowed_events = events[: search_filter.limit]
for e in allowed_events:
rm = room_groups.setdefault(
@@ -271,13 +271,13 @@ class SearchHandler:
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
- while len(room_events) < search_filter.limit() and i < 5:
+ while len(room_events) < search_filter.limit and i < 5:
i += 1
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
- search_filter.limit() * 2,
+ search_filter.limit * 2,
pagination_token=pagination_token,
)
@@ -299,9 +299,9 @@ class SearchHandler:
)
room_events.extend(events)
- room_events = room_events[: search_filter.limit()]
+ room_events = room_events[: search_filter.limit]
- if len(results) < search_filter.limit() * 2:
+ if len(results) < search_filter.limit * 2:
pagination_token = None
break
else:
@@ -311,7 +311,7 @@ class SearchHandler:
group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
- if room_events and len(room_events) >= search_filter.limit():
+ if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"]
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c411d69924..22c6174821 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -62,8 +62,8 @@ class FollowerTypingHandler:
if hs.should_send_federation():
self.federation = hs.get_federation_sender()
- if hs.config.worker.writers.typing != hs.get_instance_name():
- hs.get_federation_registry().register_instance_for_edu(
+ if hs.get_instance_name() not in hs.config.worker.writers.typing:
+ hs.get_federation_registry().register_instances_for_edu(
"m.typing",
hs.config.worker.writers.typing,
)
@@ -205,7 +205,7 @@ class TypingWriterHandler(FollowerTypingHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- assert hs.config.worker.writers.typing == hs.get_instance_name()
+ assert hs.get_instance_name() in hs.config.worker.writers.typing
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index c577142268..fbafffd69b 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -84,7 +84,11 @@ class HTTPConnectProxyEndpoint:
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
- def connect(self, protocolFactory: ClientFactory):
+ # Mypy encounters a false positive here: it complains that ClientFactory
+ # is incompatible with IProtocolFactory. But ClientFactory inherits from
+ # Factory, which implements IProtocolFactory. So I think this is a bug
+ # in mypy-zope.
+ def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds
)
@@ -119,13 +123,15 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
self.proxy_creds = proxy_creds
- self.on_connection = defer.Deferred()
+ self.on_connection: "defer.Deferred[None]" = defer.Deferred()
def startedConnecting(self, connector):
return self.wrapped_factory.startedConnecting(connector)
def buildProtocol(self, addr):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+ if wrapped_protocol is None:
+ raise TypeError("buildProtocol produced None instead of a Protocol")
return HTTPConnectProtocol(
self.dst_host,
@@ -235,7 +241,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.host = host
self.port = port
self.proxy_creds = proxy_creds
- self.on_connected = defer.Deferred()
+ self.on_connected: "defer.Deferred[None]" = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 602f93c497..4886626d50 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,8 @@
import logging
import threading
+import traceback
+from typing import Dict, Mapping, Set, Tuple
from prometheus_client.core import Counter, Histogram
@@ -105,19 +107,14 @@ in_flight_requests_db_sched_duration = Counter(
["method", "servlet"],
)
-# The set of all in flight requests, set[RequestMetrics]
-_in_flight_requests = set()
+_in_flight_requests: Set["RequestMetrics"] = set()
# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
-def _get_in_flight_counts():
- """Returns a count of all in flight requests by (method, server_name)
-
- Returns:
- dict[tuple[str, str], int]
- """
+def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
+ """Returns a count of all in flight requests by (method, server_name)"""
# Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics
with _in_flight_requests_lock:
@@ -127,8 +124,9 @@ def _get_in_flight_counts():
rm.update_metrics()
# Map from (method, name) -> int, the number of in flight requests of that
- # type
- counts = {}
+ # type. The key type is Tuple[str, str], but we leave the length unspecified
+ # for compatability with LaterGauge's annotations.
+ counts: Dict[Tuple[str, ...], int] = {}
for rm in reqs:
key = (rm.method, rm.name)
counts[key] = counts.get(key, 0) + 1
@@ -145,15 +143,21 @@ LaterGauge(
class RequestMetrics:
- def start(self, time_sec, name, method):
- self.start = time_sec
+ def start(self, time_sec: float, name: str, method: str) -> None:
+ self.start_ts = time_sec
self.start_context = current_context()
self.name = name
self.method = method
- # _request_stats records resource usage that we have already added
- # to the "in flight" metrics.
- self._request_stats = self.start_context.get_resource_usage()
+ if self.start_context:
+ # _request_stats records resource usage that we have already added
+ # to the "in flight" metrics.
+ self._request_stats = self.start_context.get_resource_usage()
+ else:
+ logger.error(
+ "Tried to start a RequestMetric from the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
with _in_flight_requests_lock:
_in_flight_requests.add(self)
@@ -169,12 +173,18 @@ class RequestMetrics:
tag = context.tag
if context != self.start_context:
- logger.warning(
+ logger.error(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
)
return
+ else:
+ logger.error(
+ "Trying to stop RequestMetrics in the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
+ return
response_code = str(response_code)
@@ -183,7 +193,7 @@ class RequestMetrics:
response_count.labels(self.method, self.name, tag).inc()
response_timer.labels(self.method, self.name, tag, response_code).observe(
- time_sec - self.start
+ time_sec - self.start_ts
)
resource_usage = context.get_resource_usage()
@@ -213,6 +223,12 @@ class RequestMetrics:
def update_metrics(self):
"""Updates the in flight metrics with values from this request."""
+ if not self.start_context:
+ logger.error(
+ "Tried to update a RequestMetric from the sentinel context.\n%s",
+ "".join(traceback.format_stack()),
+ )
+ return
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index bdc0187743..d8ae3188b7 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -220,7 +220,7 @@ class _Sentinel:
self.scope = None
self.tag = None
- def __str__(self):
+ def __str__(self) -> str:
return "sentinel"
def copy_to(self, record):
@@ -241,7 +241,7 @@ class _Sentinel:
def record_event_fetch(self, event_count):
pass
- def __bool__(self):
+ def __bool__(self) -> Literal[False]:
return False
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 08895e72ee..4a01b902c2 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -16,6 +16,7 @@
import logging
from functools import wraps
from inspect import getcallargs
+from typing import Callable, TypeVar, cast
_TIME_FUNC_ID = 0
@@ -41,7 +42,10 @@ def _log_debug_as_f(f, msg, msg_args):
logger.handle(record)
-def log_function(f):
+F = TypeVar("F", bound=Callable)
+
+
+def log_function(f: F) -> F:
"""Function decorator that logs every call to that function."""
func_name = f.__name__
@@ -69,4 +73,4 @@ def log_function(f):
return f(*args, **kwargs)
wrapped.__name__ = func_name
- return wrapped
+ return cast(F, wrapped)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index f237b8a236..91ee5c8193 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,7 +20,7 @@ import os
import platform
import threading
import time
-from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
import attr
from prometheus_client import Counter, Gauge, Histogram
@@ -32,6 +32,7 @@ from prometheus_client.core import (
)
from twisted.internet import reactor
+from twisted.python.threadpool import ThreadPool
import synapse
from synapse.metrics._exposition import (
@@ -67,7 +68,11 @@ class LaterGauge:
labels = attr.ib(hash=False, type=Optional[Iterable[str]])
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
- caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]])
+ caller = attr.ib(
+ type=Callable[
+ [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
+ ]
+ )
def collect(self):
@@ -80,11 +85,11 @@ class LaterGauge:
yield g
return
- if isinstance(calls, dict):
+ if isinstance(calls, (int, float)):
+ g.add_metric([], calls)
+ else:
for k, v in calls.items():
g.add_metric(k, v)
- else:
- g.add_metric([], calls)
yield g
@@ -522,6 +527,42 @@ threepid_send_requests = Histogram(
labelnames=("type", "reason"),
)
+threadpool_total_threads = Gauge(
+ "synapse_threadpool_total_threads",
+ "Total number of threads currently in the threadpool",
+ ["name"],
+)
+
+threadpool_total_working_threads = Gauge(
+ "synapse_threadpool_working_threads",
+ "Number of threads currently working in the threadpool",
+ ["name"],
+)
+
+threadpool_total_min_threads = Gauge(
+ "synapse_threadpool_min_threads",
+ "Minimum number of threads configured in the threadpool",
+ ["name"],
+)
+
+threadpool_total_max_threads = Gauge(
+ "synapse_threadpool_max_threads",
+ "Maximum number of threads configured in the threadpool",
+ ["name"],
+)
+
+
+def register_threadpool(name: str, threadpool: ThreadPool) -> None:
+ """Add metrics for the threadpool."""
+
+ threadpool_total_min_threads.labels(name).set(threadpool.min)
+ threadpool_total_max_threads.labels(name).set(threadpool.max)
+
+ threadpool_total_threads.labels(name).set_function(lambda: len(threadpool.threads))
+ threadpool_total_working_threads.labels(name).set_function(
+ lambda: len(threadpool.working)
+ )
+
class ReactorLastSeenMetric:
def collect(self):
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d707a9325d..6e7f5238fe 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -33,6 +33,7 @@ import jinja2
from twisted.internet import defer
from twisted.web.resource import IResource
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.presence_router import PresenceRouter
from synapse.http.client import SimpleHttpClient
@@ -54,6 +55,7 @@ from synapse.types import (
DomainSpecificString,
JsonDict,
Requester,
+ StateMap,
UserID,
UserInfo,
create_requester,
@@ -88,6 +90,8 @@ __all__ = [
"PRESENCE_ALL_USERS",
"LoginResponse",
"JsonDict",
+ "EventBase",
+ "StateMap",
]
logger = logging.getLogger(__name__)
@@ -625,8 +629,105 @@ class ModuleApi:
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
+ async def update_room_membership(
+ self,
+ sender: str,
+ target: str,
+ room_id: str,
+ new_membership: str,
+ content: Optional[JsonDict] = None,
+ ) -> EventBase:
+ """Updates the membership of a user to the given value.
+
+ Added in Synapse v1.46.0.
+
+ Args:
+ sender: The user performing the membership change. Must be a user local to
+ this homeserver.
+ target: The user whose membership is changing. This is often the same value
+ as `sender`, but it might differ in some cases (e.g. when kicking a user,
+ the `sender` is the user performing the kick and the `target` is the user
+ being kicked).
+ room_id: The room in which to change the membership.
+ new_membership: The new membership state of `target` after this operation. See
+ https://spec.matrix.org/unstable/client-server-api/#mroommember for the
+ list of allowed values.
+ content: Additional values to include in the resulting event's content.
+
+ Returns:
+ The newly created membership event.
+
+ Raises:
+ RuntimeError if the `sender` isn't a local user.
+ ShadowBanError if a shadow-banned requester attempts to send an invite.
+ SynapseError if the module attempts to send a membership event that isn't
+ allowed, either by the server's configuration (e.g. trying to set a
+ per-room display name that's too long) or by the validation rules around
+ membership updates (e.g. the `membership` value is invalid).
+ """
+ if not self.is_mine(sender):
+ raise RuntimeError(
+ "Tried to send an event as a user that isn't local to this homeserver",
+ )
+
+ requester = create_requester(sender)
+ target_user_id = UserID.from_string(target)
+
+ if content is None:
+ content = {}
+
+ # Set the profile if not already done by the module.
+ if "avatar_url" not in content or "displayname" not in content:
+ try:
+ # Try to fetch the user's profile.
+ profile = await self._hs.get_profile_handler().get_profile(
+ target_user_id.to_string(),
+ )
+ except SynapseError as e:
+ # If the profile couldn't be found, use default values.
+ profile = {
+ "displayname": target_user_id.localpart,
+ "avatar_url": None,
+ }
+
+ if e.code != 404:
+ # If the error isn't 404, it means we tried to fetch the profile over
+ # federation but the remote server responded with a non-standard
+ # status code.
+ logger.error(
+ "Got non-404 error status when fetching profile for %s",
+ target_user_id.to_string(),
+ )
+
+ # Set the profile where it needs to be set.
+ if "avatar_url" not in content:
+ content["avatar_url"] = profile["avatar_url"]
+
+ if "displayname" not in content:
+ content["displayname"] = profile["displayname"]
+
+ event_id, _ = await self._hs.get_room_member_handler().update_membership(
+ requester=requester,
+ target=target_user_id,
+ room_id=room_id,
+ action=new_membership,
+ content=content,
+ )
+
+ # Try to retrieve the resulting event.
+ event = await self._hs.get_datastore().get_event(event_id)
+
+ # update_membership is supposed to always return after the event has been
+ # successfully persisted.
+ assert event is not None
+
+ return event
+
async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
- """Create and send an event into a room. Membership events are currently not supported.
+ """Create and send an event into a room.
+
+ Membership events are not supported by this method. To update a user's membership
+ in a room, please use the `update_room_membership` method instead.
Added in Synapse v1.22.0.
@@ -866,6 +967,52 @@ class ModuleApi:
else:
return []
+ async def get_room_state(
+ self,
+ room_id: str,
+ event_filter: Optional[Iterable[Tuple[str, Optional[str]]]] = None,
+ ) -> StateMap[EventBase]:
+ """Returns the current state of the given room.
+
+ The events are returned as a mapping, in which the key for each event is a tuple
+ which first element is the event's type and the second one is its state key.
+
+ Added in Synapse v1.47.0
+
+ Args:
+ room_id: The ID of the room to get state from.
+ event_filter: A filter to apply when retrieving events. None if no filter
+ should be applied. If provided, must be an iterable of tuples. A tuple's
+ first element is the event type and the second is the state key, or is
+ None if the state key should not be filtered on.
+ An example of a filter is:
+ [
+ ("m.room.member", "@alice:example.com"), # Member event for @alice:example.com
+ ("org.matrix.some_event", ""), # State event of type "org.matrix.some_event"
+ # with an empty string as its state key
+ ("org.matrix.some_other_event", None), # State events of type "org.matrix.some_other_event"
+ # regardless of their state key
+ ]
+ """
+ if event_filter:
+ # If a filter was provided, turn it into a StateFilter and retrieve a filtered
+ # view of the state.
+ state_filter = StateFilter.from_types(event_filter)
+ state_ids = await self._store.get_filtered_current_state_ids(
+ room_id,
+ state_filter,
+ )
+ else:
+ # If no filter was provided, get the whole state. We could also reuse the call
+ # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`,
+ # but get_filtered_current_state_ids isn't cached and `get_current_state_ids`
+ # is, so using the latter when we can is better for perf.
+ state_ids = await self._store.get_current_state_ids(room_id)
+
+ state_events = await self._store.get_events(state_ids.values())
+
+ return {key: state_events[event_id] for key, event_id in state_ids.items()}
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1acd899fab..60e5409895 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -220,6 +220,8 @@ class Notifier:
# down.
self.remote_server_up_callbacks: List[Callable[[str], None]] = []
+ self._third_party_rules = hs.get_third_party_event_rules()
+
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self._pusher_pool = hs.get_pusherpool()
@@ -267,7 +269,7 @@ class Notifier:
"""
self.replication_callbacks.append(cb)
- def on_new_room_event(
+ async def on_new_room_event(
self,
event: EventBase,
event_pos: PersistedEventPosition,
@@ -275,9 +277,10 @@ class Notifier:
extra_users: Optional[Collection[UserID]] = None,
):
"""Unwraps event and calls `on_new_room_event_args`."""
- self.on_new_room_event_args(
+ await self.on_new_room_event_args(
event_pos=event_pos,
room_id=event.room_id,
+ event_id=event.event_id,
event_type=event.type,
state_key=event.get("state_key"),
membership=event.content.get("membership"),
@@ -285,9 +288,10 @@ class Notifier:
extra_users=extra_users or [],
)
- def on_new_room_event_args(
+ async def on_new_room_event_args(
self,
room_id: str,
+ event_id: str,
event_type: str,
state_key: Optional[str],
membership: Optional[str],
@@ -302,7 +306,10 @@ class Notifier:
listening to the room, and any listeners for the users in the
`extra_users` param.
- The events can be peristed out of order. The notifier will wait
+ This also notifies modules listening on new events via the
+ `on_new_event` callback.
+
+ The events can be persisted out of order. The notifier will wait
until all previous events have been persisted before notifying
the client streams.
"""
@@ -318,6 +325,8 @@ class Notifier:
)
self._notify_pending_new_room_events(max_room_stream_token)
+ await self._third_party_rules.on_new_event(event_id)
+
self.notify_replication()
def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
@@ -374,29 +383,6 @@ class Notifier:
except Exception:
logger.exception("Error notifying application services of event")
- def _notify_app_services_ephemeral(
- self,
- stream_key: str,
- new_token: Union[int, RoomStreamToken],
- users: Optional[Collection[Union[str, UserID]]] = None,
- ) -> None:
- """Notify application services of ephemeral event activity.
-
- Args:
- stream_key: The stream the event came from.
- new_token: The value of the new stream token.
- users: The users that should be informed of the new event, if any.
- """
- try:
- stream_token = None
- if isinstance(new_token, int):
- stream_token = new_token
- self.appservice_handler.notify_interested_services_ephemeral(
- stream_key, stream_token, users or []
- )
- except Exception:
- logger.exception("Error notifying application services of event")
-
def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
self._pusher_pool.on_new_notifications(max_room_stream_token)
@@ -458,12 +444,15 @@ class Notifier:
self.notify_replication()
- # Notify appservices
- self._notify_app_services_ephemeral(
- stream_key,
- new_token,
- users,
- )
+ # Notify appservices.
+ try:
+ self.appservice_handler.notify_interested_services_ephemeral(
+ stream_key,
+ new_token,
+ users,
+ )
+ except Exception:
+ logger.exception("Error notifying application services of event")
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 0622a37ae8..009d8e77b0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -232,6 +232,8 @@ class BulkPushRuleEvaluator:
# that user, as they might not be already joined.
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
+ if not isinstance(display_name, str):
+ display_name = None
if count_as_unread:
# Add an element for the current user if the event needs to be marked as
@@ -268,7 +270,7 @@ def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
- display_name: str,
+ display_name: Optional[str],
cache: Dict[str, bool],
) -> bool:
for cond in conditions:
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 7a8dc63976..7f68092ec5 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -18,7 +18,7 @@ import re
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache
@@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent:
self._value_cache = _flatten_dict(event)
def matches(
- self, condition: Dict[str, Any], user_id: str, display_name: str
+ self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
@@ -172,7 +172,7 @@ class PushRuleEvaluatorForEvent:
return _glob_matches(pattern, haystack)
- def _contains_display_name(self, display_name: str) -> bool:
+ def _contains_display_name(self, display_name: Optional[str]) -> bool:
if not display_name:
return False
@@ -222,7 +222,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
def _flatten_dict(
- d: Union[EventBase, dict],
+ d: Union[EventBase, JsonDict],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
@@ -233,7 +233,7 @@ def _flatten_dict(
for key, value in d.items():
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
- elif hasattr(value, "items"):
+ elif isinstance(value, dict):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 961c17762e..e29ae1e375 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -207,11 +207,12 @@ class ReplicationDataHandler:
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
- self.notifier.on_new_room_event_args(
+ await self.notifier.on_new_room_event_args(
event_pos=event_pos,
max_room_stream_token=max_token,
extra_users=extra_users,
room_id=row.data.room_id,
+ event_id=row.data.event_id,
event_type=row.data.type,
state_key=row.data.state_key,
membership=row.data.membership,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 06fd06fdf3..21293038ef 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -138,7 +138,7 @@ class ReplicationCommandHandler:
if isinstance(stream, TypingStream):
# Only add TypingStream as a source on the instance in charge of
# typing.
- if hs.config.worker.writers.typing == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.typing:
self._streams_to_replicate.append(stream)
continue
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c8b188ae4e..743a01da08 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -328,8 +328,7 @@ class TypingStream(Stream):
ROW_TYPE = TypingStreamRow
def __init__(self, hs: "HomeServer"):
- writer_instance = hs.config.worker.writers.typing
- if writer_instance == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.typing:
# On the writer, query the typing handler
typing_writer_handler = hs.get_typing_writer_handler()
update_function: Callable[
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index e1506deb2b..81e98f81d6 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -25,6 +25,10 @@ from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.rest.admin.background_updates import (
+ BackgroundUpdateEnabledRestServlet,
+ BackgroundUpdateRestServlet,
+)
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
@@ -42,7 +46,6 @@ from synapse.rest.admin.registration_tokens import (
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import (
- DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
@@ -221,7 +224,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
- DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
@@ -249,6 +251,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:
SendServerNoticeServlet(hs).register(http_server)
+ BackgroundUpdateEnabledRestServlet(hs).register(http_server)
+ BackgroundUpdateRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
new file mode 100644
index 0000000000..0d0183bf20
--- /dev/null
+++ b/synapse/rest/admin/background_updates.py
@@ -0,0 +1,107 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class BackgroundUpdateEnabledRestServlet(RestServlet):
+ """Allows temporarily disabling background updates"""
+
+ PATTERNS = admin_patterns("/background_updates/enabled")
+
+ def __init__(self, hs: "HomeServer"):
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
+
+ self.data_stores = hs.get_datastores()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ # We need to check that all configured databases have updates enabled.
+ # (They *should* all be in sync.)
+ enabled = all(db.updates.enabled for db in self.data_stores.databases)
+
+ return 200, {"enabled": enabled}
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ body = parse_json_object_from_request(request)
+
+ enabled = body.get("enabled", True)
+
+ if not isinstance(enabled, bool):
+ raise SynapseError(400, "'enabled' parameter must be a boolean")
+
+ for db in self.data_stores.databases:
+ db.updates.enabled = enabled
+
+ # If we're re-enabling them ensure that we start the background
+ # process again.
+ if enabled:
+ db.updates.start_doing_background_updates()
+
+ return 200, {"enabled": enabled}
+
+
+class BackgroundUpdateRestServlet(RestServlet):
+ """Fetch information about background updates"""
+
+ PATTERNS = admin_patterns("/background_updates/status")
+
+ def __init__(self, hs: "HomeServer"):
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
+
+ self.data_stores = hs.get_datastores()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ # We need to check that all configured databases have updates enabled.
+ # (They *should* all be in sync.)
+ enabled = all(db.updates.enabled for db in self.data_stores.databases)
+
+ current_updates = {}
+
+ for db in self.data_stores.databases:
+ update = db.updates.get_current_update()
+ if not update:
+ continue
+
+ current_updates[db.name()] = {
+ "name": update.name,
+ "total_item_count": update.total_item_count,
+ "total_duration_ms": update.total_duration_ms,
+ "average_items_per_ms": update.average_items_per_ms(),
+ }
+
+ return 200, {"enabled": enabled, "current_updates": current_updates}
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index a4823ca6e7..05c5b4bf0c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -46,41 +46,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class DeleteRoomRestServlet(RestServlet):
- """Delete a room from server.
-
- It is a combination and improvement of shutdown and purge room.
-
- Shuts down a room by removing all local users from the room.
- Blocking all future invites and joins to the room is optional.
-
- If desired any local aliases will be repointed to a new room
- created by `new_room_user_id` and kicked users will be auto-
- joined to the new room.
-
- If 'purge' is true, it will remove all traces of a room from the database.
- """
-
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.auth = hs.get_auth()
- self.room_shutdown_handler = hs.get_room_shutdown_handler()
- self.pagination_handler = hs.get_pagination_handler()
-
- async def on_POST(
- self, request: SynapseRequest, room_id: str
- ) -> Tuple[int, JsonDict]:
- return await _delete_room(
- request,
- room_id,
- self.auth,
- self.room_shutdown_handler,
- self.pagination_handler,
- )
-
-
class ListRoomRestServlet(RestServlet):
"""
List all rooms that are known to the homeserver. Results are returned
@@ -218,7 +183,7 @@ class RoomRestServlet(RestServlet):
async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
- return await _delete_room(
+ return await self._delete_room(
request,
room_id,
self.auth,
@@ -226,6 +191,58 @@ class RoomRestServlet(RestServlet):
self.pagination_handler,
)
+ async def _delete_room(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ auth: "Auth",
+ room_shutdown_handler: "RoomShutdownHandler",
+ pagination_handler: "PaginationHandler",
+ ) -> Tuple[int, JsonDict]:
+ requester = await auth.get_user_by_req(request)
+ await assert_user_is_admin(auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+
+ block = content.get("block", False)
+ if not isinstance(block, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'block' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ purge = content.get("purge", True)
+ if not isinstance(purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ force_purge = content.get("force_purge", False)
+ if not isinstance(force_purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'force_purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ ret = await room_shutdown_handler.shutdown_room(
+ room_id=room_id,
+ new_room_user_id=content.get("new_room_user_id"),
+ new_room_name=content.get("room_name"),
+ message=content.get("message"),
+ requester_user_id=requester.user.to_string(),
+ block=block,
+ )
+
+ # Purge room
+ if purge:
+ await pagination_handler.purge_room(room_id, force=force_purge)
+
+ return 200, ret
+
class RoomMembersRestServlet(RestServlet):
"""
@@ -617,55 +634,3 @@ class RoomEventContextServlet(RestServlet):
)
return 200, results
-
-
-async def _delete_room(
- request: SynapseRequest,
- room_id: str,
- auth: "Auth",
- room_shutdown_handler: "RoomShutdownHandler",
- pagination_handler: "PaginationHandler",
-) -> Tuple[int, JsonDict]:
- requester = await auth.get_user_by_req(request)
- await assert_user_is_admin(auth, requester.user)
-
- content = parse_json_object_from_request(request)
-
- block = content.get("block", False)
- if not isinstance(block, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'block' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- purge = content.get("purge", True)
- if not isinstance(purge, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'purge' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- force_purge = content.get("force_purge", False)
- if not isinstance(force_purge, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'force_purge' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- ret = await room_shutdown_handler.shutdown_room(
- room_id=room_id,
- new_room_user_id=content.get("new_room_user_id"),
- new_room_name=content.get("room_name"),
- message=content.get("message"),
- requester_user_id=requester.user.to_string(),
- block=block,
- )
-
- # Purge room
- if purge:
- await pagination_handler.purge_room(room_id, force=force_purge)
-
- return 200, ret
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 9770413c61..2b25b9aad6 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,10 +13,12 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -24,6 +26,8 @@ from synapse.types import JsonDict
from ._base import client_patterns
+pattern = re.compile(r"(?:Element|SchildiChat)/1\.[012]\.")
+
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -52,7 +56,13 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
- body = parse_json_object_from_request(request, allow_empty_body=True)
+ # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
+ user_agent = get_request_user_agent(request)
+ allow_empty_body = False
+ if "Android" in user_agent:
+ if pattern.match(user_agent) or "Riot" in user_agent:
+ allow_empty_body = True
+ body = parse_json_object_from_request(request, allow_empty_body)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
if not isinstance(hidden, bool):
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index ed95189b6d..6a876cfa2f 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -914,7 +914,7 @@ class RoomTypingRestServlet(RestServlet):
# If we're not on the typing writer instance we should scream if we get
# requests.
self._is_typing_writer = (
- hs.config.worker.writers.typing == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.typing
)
async def on_PUT(
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 99f8156ad0..e4c9451ae0 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -112,7 +112,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
# and have the batch connected.
if batch_id_from_query:
corresponding_insertion_event_id = (
- await self.store.get_insertion_event_by_batch_id(
+ await self.store.get_insertion_event_id_by_batch_id(
room_id, batch_id_from_query
)
)
@@ -131,20 +131,22 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_event_ids_from_query
)
+ state_event_ids_at_start = []
# Create and persist all of the state events that float off on their own
# before the batch. These will most likely be all of the invite/member
# state events used to auth the upcoming historical messages.
- state_event_ids_at_start = (
- await self.room_batch_handler.persist_state_events_at_start(
- state_events_at_start=body["state_events_at_start"],
- room_id=room_id,
- initial_auth_event_ids=auth_event_ids,
- app_service_requester=requester,
+ if body["state_events_at_start"]:
+ state_event_ids_at_start = (
+ await self.room_batch_handler.persist_state_events_at_start(
+ state_events_at_start=body["state_events_at_start"],
+ room_id=room_id,
+ initial_auth_event_ids=auth_event_ids,
+ app_service_requester=requester,
+ )
)
- )
- # Update our ongoing auth event ID list with all of the new state we
- # just created
- auth_event_ids.extend(state_event_ids_at_start)
+ # Update our ongoing auth event ID list with all of the new state we
+ # just created
+ auth_event_ids.extend(state_event_ids_at_start)
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
@@ -191,14 +193,17 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
- batch_id_to_connect_to = base_insertion_event["content"][
+ batch_id_to_connect_to = base_insertion_event.content[
EventContentFields.MSC2716_NEXT_BATCH_ID
]
# Also connect the historical event chain to the end of the floating
# state chain, which causes the HS to ask for the state at the start of
- # the batch later.
- prev_event_ids = [state_event_ids_at_start[-1]]
+ # the batch later. If there is no state chain to connect to, just make
+ # the insertion event float itself.
+ prev_event_ids = []
+ if len(state_event_ids_at_start):
+ prev_event_ids = [state_event_ids_at_start[-1]]
# Create and persist all of the historical events as well as insertion
# and batch meta events to make the batch navigable in the DAG.
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index b52a296d8f..8d888f4565 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -72,6 +72,7 @@ class VersionsRestServlet(RestServlet):
"r0.4.0",
"r0.5.0",
"r0.6.0",
+ "r0.6.1",
],
# as per MSC1497:
"unstable_features": {
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index abd88a2d4f..244ba261bb 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -215,6 +215,8 @@ class MediaRepository:
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
+ if not media_type:
+ media_type = "application/octet-stream"
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
@@ -333,6 +335,9 @@ class MediaRepository:
logger.info("Media is quarantined")
raise NotFoundError()
+ if not media_info["media_type"]:
+ media_info["media_type"] = "application/octet-stream"
+
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
@@ -354,6 +359,8 @@ class MediaRepository:
raise e
file_id = media_info["filesystem_id"]
+ if not media_info["media_type"]:
+ media_info["media_type"] = "application/octet-stream"
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@@ -445,7 +452,10 @@ class MediaRepository:
await finish()
- media_type = headers[b"Content-Type"][0].decode("ascii")
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 7dcb1428e4..8162094cf6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -80,7 +80,7 @@ class UploadResource(DirectServeJsonResource):
assert content_type_headers # for mypy
media_type = content_type_headers[0].decode("ascii")
else:
- raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
+ media_type = "application/octet-stream"
# if headers.hasHeader(b"Content-Disposition"):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 7ac01faab4..04b035a1b1 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -21,6 +21,7 @@ from twisted.web.server import Request
from synapse.http.server import set_cors_headers
from synapse.types import JsonDict
from synapse.util import json_encoder
+from synapse.util.stringutils import parse_server_name
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -33,8 +34,7 @@ class WellKnownBuilder:
self._config = hs.config
def get_well_known(self) -> Optional[JsonDict]:
- # if we don't have a public_baseurl, we can't help much here.
- if self._config.server.public_baseurl is None:
+ if not self._config.server.serve_client_wellknown:
return None
result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}}
@@ -47,8 +47,8 @@ class WellKnownBuilder:
return result
-class WellKnownResource(Resource):
- """A Twisted web resource which renders the .well-known file"""
+class ClientWellKnownResource(Resource):
+ """A Twisted web resource which renders the .well-known/matrix/client file"""
isLeaf = 1
@@ -67,3 +67,45 @@ class WellKnownResource(Resource):
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
return json_encoder.encode(r).encode("utf-8")
+
+
+class ServerWellKnownResource(Resource):
+ """Resource for .well-known/matrix/server, redirecting to port 443"""
+
+ isLeaf = 1
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._serve_server_wellknown = hs.config.server.serve_server_wellknown
+
+ host, port = parse_server_name(hs.config.server.server_name)
+
+ # If we've got this far, then https://<server_name>/ must route to us, so
+ # we just redirect the traffic to port 443 instead of 8448.
+ if port is None:
+ port = 443
+
+ self._response = json_encoder.encode({"m.server": f"{host}:{port}"}).encode(
+ "utf-8"
+ )
+
+ def render_GET(self, request: Request) -> bytes:
+ if not self._serve_server_wellknown:
+ request.setResponseCode(404)
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b"404. Is anything ever truly *well* known?\n"
+
+ request.setHeader(b"Content-Type", b"application/json")
+ return self._response
+
+
+def well_known_resource(hs: "HomeServer") -> Resource:
+ """Returns a Twisted web resource which handles '.well-known' requests"""
+ res = Resource()
+ matrix_resource = Resource()
+ res.putChild(b"matrix", matrix_resource)
+
+ matrix_resource.putChild(b"server", ServerWellKnownResource(hs))
+ matrix_resource.putChild(b"client", ClientWellKnownResource(hs))
+
+ return res
diff --git a/synapse/server.py b/synapse/server.py
index 0fbf36ba99..013a7bacaa 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -463,7 +463,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_typing_writer_handler(self) -> TypingWriterHandler:
- if self.config.worker.writers.typing == self.get_instance_name():
+ if self.get_instance_name() in self.config.worker.writers.typing:
return TypingWriterHandler(self)
else:
raise Exception("Workers cannot write typing")
@@ -474,7 +474,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_typing_handler(self) -> FollowerTypingHandler:
- if self.config.worker.writers.typing == self.get_instance_name():
+ if self.get_instance_name() in self.config.worker.writers.typing:
# Use get_typing_writer_handler to ensure that we use the same
# cached version.
return self.get_typing_writer_handler()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 5cf2e12575..1605411b00 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -26,6 +26,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Sequence,
Set,
@@ -246,7 +247,7 @@ class StateHandler:
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
- self, room_id: str, event_ids: List[str]
+ self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
@@ -519,7 +520,7 @@ class StateResolutionHandler:
self,
room_id: str,
room_version: str,
- state_groups_ids: Dict[int, StateMap[str]],
+ state_groups_ids: Mapping[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> _StateCacheEntry:
@@ -703,7 +704,7 @@ class StateResolutionHandler:
def _make_state_cache_entry(
- new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
+ new_state: StateMap[str], state_groups_ids: Mapping[int, StateMap[str]]
) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 82b31d24f1..b9a8ca997e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -100,29 +100,58 @@ class BackgroundUpdater:
] = {}
self._all_done = False
+ # Whether we're currently running updates
+ self._running = False
+
+ # Whether background updates are enabled. This allows us to
+ # enable/disable background updates via the admin API.
+ self.enabled = True
+
+ def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
+ """Returns the current background update, if any."""
+
+ update_name = self._current_background_update
+ if not update_name:
+ return None
+
+ perf = self._background_update_performance.get(update_name)
+ if not perf:
+ perf = BackgroundUpdatePerformance(update_name)
+
+ return perf
+
def start_doing_background_updates(self) -> None:
- run_as_background_process("background_updates", self.run_background_updates)
+ if self.enabled:
+ run_as_background_process("background_updates", self.run_background_updates)
async def run_background_updates(self, sleep: bool = True) -> None:
- logger.info("Starting background schema updates")
- while True:
- if sleep:
- await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+ if self._running or not self.enabled:
+ return
- try:
- result = await self.do_next_background_update(
- self.BACKGROUND_UPDATE_DURATION_MS
- )
- except Exception:
- logger.exception("Error doing update")
- else:
- if result:
- logger.info(
- "No more background updates to do."
- " Unscheduling background update task."
+ self._running = True
+
+ try:
+ logger.info("Starting background schema updates")
+ while self.enabled:
+ if sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+
+ try:
+ result = await self.do_next_background_update(
+ self.BACKGROUND_UPDATE_DURATION_MS
)
- self._all_done = True
- return None
+ except Exception:
+ logger.exception("Error doing update")
+ else:
+ if result:
+ logger.info(
+ "No more background updates to do."
+ " Unscheduling background update task."
+ )
+ self._all_done = True
+ return None
+ finally:
+ self._running = False
async def has_completed_background_updates(self) -> bool:
"""Check if all the background updates have completed
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index fa4e89d35c..d4cab69ebf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -48,6 +48,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
+from synapse.metrics import register_threadpool
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
@@ -104,13 +105,17 @@ def make_pool(
LoggingDatabaseConnection(conn, engine, "on_new_connection")
)
- return adbapi.ConnectionPool(
+ connection_pool = adbapi.ConnectionPool(
db_config.config["name"],
cp_reactor=reactor,
cp_openfun=_on_new_connection,
**db_args,
)
+ register_threadpool(f"database-{db_config.name}", connection_pool.threadpool)
+
+ return connection_pool
+
def make_conn(
db_config: DatabaseConnectionConfig,
@@ -441,6 +446,10 @@ class DatabasePool:
self._check_safe_to_upsert,
)
+ def name(self) -> str:
+ "Return the name of this database"
+ return self._database_config.name
+
def is_running(self) -> bool:
"""Is the database pool currently running"""
return self._db_pool.running
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8143168107..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -488,10 +489,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
+ # We exclude hidden devices (such as cross-signing keys) here as they are
+ # not expected to receive to-device messages.
devices = self.db_pool.simple_select_onecol_txn(
txn,
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
retcol="device_id",
)
@@ -504,10 +507,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not devices:
continue
+ # We exclude hidden devices (such as cross-signing keys) here as they are
+ # not expected to receive to-device messages.
rows = self.db_pool.simple_select_many_txn(
txn,
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
@@ -555,6 +560,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+ REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+ REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -570,6 +577,16 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_DELETED_DEVICES,
+ self._remove_deleted_devices_from_device_inbox,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_HIDDEN_DEVICES,
+ self._remove_hidden_devices_from_device_inbox,
+ )
+
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
@@ -582,6 +599,172 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
+ async def _remove_deleted_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for deleted devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_deleted_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all dead device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) NOT IN (
+ SELECT device_id, user_id FROM devices
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # send more than stream_id to progress
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_DELETED_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_devices_from_device_inbox",
+ _remove_deleted_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_DELETED_DEVICES
+ )
+
+ return number_deleted
+
+ async def _remove_hidden_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for hidden devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_hidden_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all hidden device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) IN (
+ SELECT device_id, user_id FROM devices WHERE hidden = ?
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, True, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # We don't just save the `stream_id` in progress as
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file, as
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_HIDDEN_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_hidden_devices_from_device_inbox",
+ _remove_hidden_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_HIDDEN_DEVICES
+ )
+
+ return number_deleted
+
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a01bf2c5b7..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_ids: the users who were signed
Returns:
- THe new stream ID.
+ The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
@@ -1134,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
raise StoreError(500, "Problem storing device.")
async def delete_device(self, user_id: str, device_id: str) -> None:
- """Delete a device.
+ """Delete a device and its device_inbox.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to delete
"""
- await self.db_pool.simple_delete_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
- desc="delete_device",
- )
- self.device_id_exists_cache.invalidate((user_id, device_id))
+ await self.delete_devices(user_id, [device_id])
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
@@ -1155,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user which owns the devices
device_ids: The IDs of the devices to delete
"""
- await self.db_pool.simple_delete_many(
- table="devices",
- column="device_id",
- iterable=device_ids,
- keyvalues={"user_id": user_id, "hidden": False},
- desc="delete_devices",
- )
+
+ def _delete_devices_txn(txn: LoggingTransaction) -> None:
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="devices",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id, "hidden": False},
+ )
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_inbox",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
+ await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
@@ -1315,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ):
+ ) -> int:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
)
@@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
)
# Insert an edge for every prev_event connection
- for prev_event_id in event.prev_events:
+ for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn(
txn,
table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ae37901be9..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -28,6 +28,7 @@ from typing import (
import attr
from constantly import NamedConstant, Names
+from prometheus_client import Gauge
from typing_extensions import Literal
from twisted.internet import defer
@@ -81,6 +82,12 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+event_fetch_ongoing_gauge = Gauge(
+ "synapse_event_fetch_ongoing",
+ "The number of event fetchers that are running",
+)
+
+
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
event: EventBase
@@ -222,6 +229,7 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
@@ -732,28 +740,31 @@ class EventsWorkerStore(SQLBaseStore):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- self._fetch_event_list(conn, event_list)
+ try:
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ break
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+ finally:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -977,6 +988,7 @@ class EventsWorkerStore(SQLBaseStore):
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
should_start = True
else:
should_start = False
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
_DEFAULT_CATEGORY_ID = ""
@@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ database.updates.register_background_index_update(
+ update_name="local_group_updates_index",
+ index_name="local_group_updates_stream_id_index",
+ table="local_group_updates",
+ columns=("stream_id",),
+ unique=True,
+ )
+ super().__init__(database, db_conn, hs)
+
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d1dff660b..3d0df0cbd4 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -14,6 +14,7 @@
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -61,7 +62,7 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
- self._live_tokens: Dict[Tuple[str, str], str] = {}
+ self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
@@ -80,10 +81,10 @@ class LockStore(SQLBaseStore):
# We need to take a copy of the tokens dict as dropping the locks will
# cause the dictionary to change.
- tokens = dict(self._live_tokens)
+ locks = dict(self._live_tokens)
- for (lock_name, lock_key), token in tokens.items():
- await self._drop_lock(lock_name, lock_key, token)
+ for lock in locks.values():
+ await lock.release()
logger.info("Dropped locks due to shutdown")
@@ -93,6 +94,11 @@ class LockStore(SQLBaseStore):
used (otherwise the lock will leak).
"""
+ # Check if this process has taken out a lock and if it's still valid.
+ lock = self._live_tokens.get((lock_name, lock_key))
+ if lock and await lock.is_still_valid():
+ return None
+
now = self._clock.time_msec()
token = random_string(6)
@@ -100,7 +106,9 @@ class LockStore(SQLBaseStore):
def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
# We take out the lock if either a) there is no row for the lock
- # already or b) the existing row has timed out.
+ # already, b) the existing row has timed out, or c) the row is
+ # for this instance (which means the process got killed and
+ # restarted)
sql = """
INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?)
@@ -112,6 +120,7 @@ class LockStore(SQLBaseStore):
last_renewed_ts = EXCLUDED.last_renewed_ts
WHERE
worker_locks.last_renewed_ts < ?
+ OR worker_locks.instance_name = EXCLUDED.instance_name
"""
txn.execute(
sql,
@@ -148,11 +157,11 @@ class LockStore(SQLBaseStore):
WHERE
lock_name = ?
AND lock_key = ?
- AND last_renewed_ts < ?
+ AND (last_renewed_ts < ? OR instance_name = ?)
"""
txn.execute(
sql,
- (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+ (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
)
inserted = self.db_pool.simple_upsert_txn_emulated(
@@ -179,9 +188,7 @@ class LockStore(SQLBaseStore):
if not did_lock:
return None
- self._live_tokens[(lock_name, lock_key)] = token
-
- return Lock(
+ lock = Lock(
self._reactor,
self._clock,
self,
@@ -190,6 +197,10 @@ class LockStore(SQLBaseStore):
token=token,
)
+ self._live_tokens[(lock_name, lock_key)] = lock
+
+ return lock
+
async def _is_lock_still_valid(
self, lock_name: str, lock_key: str, token: str
) -> bool:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill,
)
- async def update_presence(self, presence_states):
+ async def update_presence(self, presence_states) -> Tuple[int, int]:
assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index ba7075caa5..dd8e27e226 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def update_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
+ self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
) -> int:
return await self.db_pool.simple_update(
table="remote_profile_cache",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 40760fbd1b..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
# limitations under the License.
import logging
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?"]
- where_args = [event_id]
+ where_args: List[Union[str, int]] = [event_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
order,
)
- def _get_recent_references_for_event_txn(txn):
+ def _get_recent_references_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args = [event_id, RelationTypes.ANNOTATION]
+ where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
if event_type:
where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause=having_clause,
)
- def _get_aggregation_groups_for_event_txn(txn):
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1
"""
- def _get_applicable_edit_txn(txn):
+ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
+ return None
edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
if not edit_id:
return None
- return await self.get_event(edit_id, allow_none=True)
+ return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
@cached()
async def get_thread_summary(
@@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
The number of items in the thread and the most recent response, if any.
"""
- def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+ def _get_thread_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
sql = """
@@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
- count = txn.fetchone()[0]
+ count = txn.fetchone()[0] # type: ignore[index]
return count, latest_event_id
@@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event = None
if latest_event_id:
- latest_event = await self.get_event(latest_event_id, allow_none=True)
+ latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
return count, latest_event
@@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1;
"""
- def _get_if_user_has_annotated_event(txn):
+ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute(
sql,
(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
limit: maximum amount of rooms to retrieve
order_by: the sort order of the returned list
reverse_order: whether to reverse the room list
- search_term: a string to filter room names by
+ search_term: a string to filter room names,
+ canonical alias and room ids by.
+ Room ID must match exactly. Canonical alias must match a substring of the local part.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
"""
# Filter room names by a string
where_statement = ""
+ search_pattern = []
if search_term:
- where_statement = "WHERE LOWER(state.name) LIKE ?"
+ where_statement = """
+ WHERE LOWER(state.name) LIKE ?
+ OR LOWER(state.canonical_alias) LIKE ?
+ OR state.room_id = ?
+ """
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term.lower() + "%"
+ search_pattern = [
+ "%" + search_term.lower() + "%",
+ "#%" + search_term.lower() + "%:%",
+ search_term,
+ ]
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
)
def _get_rooms_paginate_txn(txn):
- # Execute the data query
- sql_values = (limit, start)
- if search_term:
- # Add the search term into the WHERE clause
- sql_values = (search_term,) + sql_values
- txn.execute(info_sql, sql_values)
+ # Add the search term into the WHERE clause
+ # and execute the data query
+ txn.execute(info_sql, search_pattern + [limit, start])
# Refactor room query data into a structured dictionary
rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
# Execute the count query
# Add the search term into the WHERE clause if present
- sql_values = (search_term,) if search_term else ()
- txn.execute(count_sql, sql_values)
+ txn.execute(count_sql, search_pattern)
room_count = txn.fetchone()
return rooms, room_count[0]
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index dcbce8fdcf..97b2618437 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -18,7 +18,7 @@ from synapse.storage._base import SQLBaseStore
class RoomBatchStore(SQLBaseStore):
- async def get_insertion_event_by_batch_id(
+ async def get_insertion_event_id_by_batch_id(
self, room_id: str, batch_id: str
) -> Optional[str]:
"""Retrieve a insertion event ID.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
- ):
+ ) -> Dict[str, ProfileInfo]:
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context
)
- async def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(
+ self, room_id, state_entry
+ ) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context,
event=None,
context=None,
- ):
+ ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1629d2a53c..e45adfcb55 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -131,17 +131,9 @@ def prepare_database(
"config==None in prepare_database, but database is not empty"
)
- # if it's a worker app, refuse to upgrade the database, to avoid multiple
- # workers doing it at once.
- if (
- config.worker.worker_app is not None
- and version_info.current_version != SCHEMA_VERSION
- ):
- raise UpgradeDatabaseException(
- OUTDATED_SCHEMA_ON_WORKER_ERROR
- % (SCHEMA_VERSION, version_info.current_version)
- )
-
+ # This should be run on all processes, master or worker. The master will
+ # apply the deltas, while workers will check if any outstanding deltas
+ # exist and raise an PrepareDatabaseException if they do.
_upgrade_existing_database(
cur,
version_info,
@@ -149,6 +141,7 @@ def prepare_database(
config,
databases=databases,
)
+
else:
logger.info("%r: Initialising new database", databases)
@@ -357,6 +350,18 @@ def _upgrade_existing_database(
is_worker = config and config.worker.worker_app is not None
+ # If the schema version needs to be updated, and we are on a worker, we immediately
+ # know to bail out as workers cannot update the database schema. Only one process
+ # must update the database at the time, therefore we delegate this task to the master.
+ if is_worker and current_schema_state.current_version < SCHEMA_VERSION:
+ # If the DB is on an older version than we expect then we refuse
+ # to start the worker (as the main process needs to run first to
+ # update the schema).
+ raise UpgradeDatabaseException(
+ OUTDATED_SCHEMA_ON_WORKER_ERROR
+ % (SCHEMA_VERSION, current_schema_state.current_version)
+ )
+
if (
current_schema_state.compat_version is not None
and current_schema_state.compat_version > SCHEMA_VERSION
diff --git a/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..7b3592dcf0
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- because a device was hidden using Synapse earlier than 1.47.0.
+-- This runs as background task, but may take a bit to finish.
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6503, 'remove_hidden_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/04_local_group_updates.sql b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
new file mode 100644
index 0000000000..a178abfe12
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Check index on `local_group_updates.stream_id`.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6504, 'local_group_updates_index', '{}');
diff --git a/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..82f6408b36
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- when a device was deleted using Synapse earlier than 1.47.0.
+-- This runs as background task, but may take a bit to finish.
+
+-- Remove any existing instances of this job running. It's OK to stop and restart this job,
+-- as it's just deleting entries from a table - no progress will be lost.
+--
+-- This is necessary due a similar migration running the job accidentally
+-- being included in schema version 64 during v1.47.0rc1,rc2. If a
+-- homeserver had updated from Synapse <=v1.45.0 (schema version <=64),
+-- then they would have started running this background update already.
+-- If that update was still running, then simply inserting it again would
+-- cause an SQL failure. So we effectively do an "upsert" here instead.
+
+DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6506, 'remove_deleted_devices_from_device_inbox', '{}');
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5df80ea8e7..96efc5f3e3 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,11 +22,11 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Generic,
Hashable,
Iterable,
- List,
Optional,
Set,
TypeVar,
@@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", set())
+ object.__setattr__(self, "_observers", [])
def callback(r):
object.__setattr__(self, "_result", (True, r))
- while self._observers:
- observer = self._observers.pop()
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
try:
observer.callback(r)
except Exception as e:
@@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
def errback(f):
object.__setattr__(self, "_result", (False, f))
- while self._observers:
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
-
- observer = self._observers.pop()
try:
observer.errback(f)
except Exception as e:
@@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
"""
if not self._result:
d: "defer.Deferred[_T]" = defer.Deferred()
-
- def remove(r):
- self._observers.discard(d)
- return r
-
- d.addBoth(remove)
-
- self._observers.add(d)
+ self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> "List[defer.Deferred[_T]]":
+ def observers(self) -> "Collection[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
|