diff options
Diffstat (limited to 'synapse')
75 files changed, 1807 insertions, 767 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 395e202b89..9840a9d55b 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -16,6 +16,7 @@ import gc import logging import os +import platform import signal import socket import sys @@ -339,7 +340,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon # rest of time. Doing so means less work each GC (hopefully). # # This only works on Python 3.7 - if sys.version_info >= (3, 7): + if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7): gc.collect() gc.freeze() diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index c38cf8231f..8f86cecb76 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): stats["daily_active_users"] = await hs.get_datastore().count_daily_users() stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms() + stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms + stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages() + daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages() + stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = await hs.get_datastore().count_daily_messages() + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages r30_results = await hs.get_datastore().count_r30_users() for name, count in r30_results.items(): stats["r30_users_" + name] = count - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 35e5594b73..a851f8801d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -203,11 +203,28 @@ class Config: with open(file_path) as file_stream: return file_stream.read() + def read_template(self, filename: str) -> jinja2.Template: + """Load a template file from disk. + + This function will attempt to load the given template from the default Synapse + template directory. + + Files read are treated as Jinja templates. The templates is not rendered yet + and has autoescape enabled. + + Args: + filename: A template filename to read. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A jinja2 template. + """ + return self.read_templates([filename])[0] + def read_templates( - self, - filenames: List[str], - custom_template_directory: Optional[str] = None, - autoescape: bool = False, + self, filenames: List[str], custom_template_directory: Optional[str] = None, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -215,7 +232,8 @@ class Config: template directory. If `custom_template_directory` is supplied, that directory is tried first. - Files read are treated as Jinja templates. These templates are not rendered yet. + Files read are treated as Jinja templates. The templates are not rendered yet + and have autoescape enabled. Args: filenames: A list of template filenames to read. @@ -223,16 +241,12 @@ class Config: custom_template_directory: A directory to try to look for the templates before using the default Synapse template directory instead. - autoescape: Whether to autoescape variables before inserting them into the - template. - Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. Returns: A list of jinja2 templates. """ - templates = [] search_directories = [self.default_template_dir] # The loader will first look in the custom template directory (if specified) for the @@ -250,7 +264,7 @@ class Config: # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) - env = jinja2.Environment(loader=loader, autoescape=autoescape) + env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),) # Update the environment with our custom filters env.filters.update( @@ -260,12 +274,8 @@ class Config: } ) - for filename in filenames: - # Load the template - template = env.get_template(filename) - templates.append(template) - - return templates + # Load the templates + return [env.get_template(filename) for filename in filenames] class RootConfig: diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 3ccea4b02d..70025b5d60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -19,6 +19,7 @@ from synapse.config import ( password_auth_providers, push, ratelimiting, + redis, registration, repository, room_directory, @@ -53,7 +54,7 @@ class RootConfig: tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig - ratelimit: ratelimiting.RatelimitConfig + ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig @@ -81,6 +82,7 @@ class RootConfig: roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig + redis: redis.RedisConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index cb00958165..9e48f865cc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -28,9 +28,7 @@ class CaptchaConfig(Config): "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", ) - self.recaptcha_template = self.read_templates( - ["recaptcha.html"], autoescape=True - )[0] + self.recaptcha_template = self.read_template("recaptcha.html") def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/cas.py b/synapse/config/cas.py index c7877b4095..b226890c2a 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -30,7 +30,13 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - self.cas_service_url = cas_config["service_url"] + public_base_url = cas_config.get("service_url") or self.public_baseurl + if public_base_url[-1] != "/": + public_base_url += "/" + # TODO Update this to a _synapse URL. + self.cas_service_url = ( + public_base_url + "_matrix/client/r0/login/cas/ticket" + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: @@ -53,10 +59,6 @@ class CasConfig(Config): # #server_url: "https://cas-server.com" - # The public URL of the homeserver. - # - #service_url: "https://homeserver.domain.com:8448" - # The attribute of the CAS response to use as the display name. # # If unset, no displayname will be set. diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 6efa59b110..c47f364b14 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -89,7 +89,7 @@ class ConsentConfig(Config): def read_config(self, config, **kwargs): consent_config = config.get("user_consent") - self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0] + self.terms_template = self.read_template("terms.html") if consent_config is None: return diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 784b416f95..bb122ef182 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,8 +53,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl - self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 14b8836197..def33a60ad 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -24,7 +24,7 @@ class RateLimitConfig: defaults={"per_second": 0.17, "burst_count": 3.0}, ): self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = config.get("burst_count", defaults["burst_count"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) class FederationRateLimitConfig: @@ -102,6 +102,20 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + self.rc_3pid_validation = RateLimitConfig( + config.get("rc_3pid_validation") or {}, + defaults={"per_second": 0.003, "burst_count": 5}, + ) + + self.rc_invites_per_room = RateLimitConfig( + config.get("rc_invites", {}).get("per_room", {}), + defaults={"per_second": 0.3, "burst_count": 10}, + ) + self.rc_invites_per_user = RateLimitConfig( + config.get("rc_invites", {}).get("per_user", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -131,6 +145,9 @@ class RatelimitConfig(Config): # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) + # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. + # - two for ratelimiting how often invites can be sent in a room or to a + # specific user. # # The defaults are as shown below. # @@ -164,7 +181,18 @@ class RatelimitConfig(Config): # remote: # per_second: 0.01 # burst_count: 3 - + # + #rc_3pid_validation: + # per_second: 0.003 + # burst_count: 5 + # + #rc_invites: + # per_room: + # per_second: 0.3 + # burst_count: 10 + # per_user: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 4bfc69cb7a..ac48913a0b 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -176,9 +176,7 @@ class RegistrationConfig(Config): self.session_lifetime = session_lifetime # The success template used during fallback auth. - self.fallback_success_template = self.read_templates( - ["auth_success.html"], autoescape=True - )[0] + self.fallback_success_template = self.read_template("auth_success.html") def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 74b67b230a..14b21796d9 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -125,19 +125,24 @@ class FederationPolicyForHTTPS: self._no_verify_ssl_context = _no_verify_ssl.getContext() self._no_verify_ssl_context.set_info_callback(_context_info_cb) - def get_options(self, host: bytes): + self._should_verify = self._config.federation_verify_certificates + + self._federation_certificate_verification_whitelist = ( + self._config.federation_certificate_verification_whitelist + ) + def get_options(self, host: bytes): # IPolicyForHTTPS.get_options takes bytes, but we want to compare # against the str whitelist. The hostnames in the whitelist are already # IDNA-encoded like the hosts will be here. ascii_host = host.decode("ascii") # Check if certificate verification has been enabled - should_verify = self._config.federation_verify_certificates + should_verify = self._should_verify # Check if we've disabled certificate verification for this host - if should_verify: - for regex in self._config.federation_certificate_verification_whitelist: + if self._should_verify: + for regex in self._federation_certificate_verification_whitelist: if regex.match(ascii_host): should_verify = False break diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d330ae5dbc..40e1451201 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -810,7 +810,7 @@ class FederationClient(FederationBase): "User's homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - elif e.code == 403: + elif e.code in (403, 429): raise e.to_synapse_error() else: raise diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 604cfd1935..643b26ae6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -142,6 +142,8 @@ class FederationSender: self._wake_destinations_needing_catchup, ) + self._external_cache = hs.get_external_cache() + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -197,22 +199,40 @@ class FederationSender: if not event.internal_metadata.should_proactively_send(): return - try: - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() - ) - except Exception: - logger.exception( - "Failed to calculate hosts in room for event: %s", - event.event_id, + destinations = None # type: Optional[Set[str]] + if not event.prev_event_ids(): + # If there are no prev event IDs then the state is empty + # and so no remote servers in the room + destinations = set() + else: + # We check the external cache for the destinations, which is + # stored per state group. + + sg = await self._external_cache.get( + "event_to_prev_state_group", event.event_id ) - return + if sg: + destinations = await self._external_cache.get( + "get_joined_hosts", str(sg) + ) + + if destinations is None: + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return destinations = { d diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 8476256a59..5ecb2da1ac 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING import twisted import twisted.internet.error @@ -22,6 +23,9 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) ACME_REGISTER_FAIL_ERROR = """ @@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC class AcmeHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.reactor = hs.get_reactor() self._acme_domain = hs.config.acme_domain - async def start_listening(self): + async def start_listening(self) -> None: from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug @@ -85,7 +89,7 @@ class AcmeHandler: logger.error(ACME_REGISTER_FAIL_ERROR) raise - async def provision_certificate(self): + async def provision_certificate(self) -> None: logger.warning("Reprovisioning %s", self._acme_domain) @@ -110,5 +114,3 @@ class AcmeHandler: except Exception: logger.exception("Failed saving!") raise - - return True diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index 7294649d71..ae2a9dd9c2 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to imported conditionally. """ import logging +from typing import Dict, Iterable, List import attr +import pem from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from josepy import JWKRSA @@ -36,20 +38,27 @@ from txacme.util import generate_private_key from zope.interface import implementer from twisted.internet import defer +from twisted.internet.interfaces import IReactorTCP from twisted.python.filepath import FilePath from twisted.python.url import URL +from twisted.web.resource import IResource logger = logging.getLogger(__name__) -def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): +def create_issuing_service( + reactor: IReactorTCP, + acme_url: str, + account_key_file: str, + well_known_resource: IResource, +) -> AcmeIssuingService: """Create an ACME issuing service, and attach it to a web Resource Args: reactor: twisted reactor - acme_url (str): URL to use to request certificates - account_key_file (str): where to store the account key - well_known_resource (twisted.web.IResource): web resource for .well-known. + acme_url: URL to use to request certificates + account_key_file: where to store the account key + well_known_resource: web resource for .well-known. we will attach a child resource for "acme-challenge". Returns: @@ -83,18 +92,20 @@ class ErsatzStore: A store that only stores in memory. """ - certs = attr.ib(default=attr.Factory(dict)) + certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) - def store(self, server_name, pem_objects): + def store( + self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] + ) -> defer.Deferred: self.certs[server_name] = [o.as_bytes() for o in pem_objects] return defer.succeed(None) -def load_or_create_client_key(key_file): +def load_or_create_client_key(key_file: str) -> JWKRSA: """Load the ACME account key from a file, creating it if it does not exist. Args: - key_file (str): name of the file to use as the account key + key_file: name of the file to use as the account key """ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't # hardcode the 'client.key' filename diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6f746711ca..a19c556437 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -568,16 +568,6 @@ class AuthHandler(BaseHandler): session.session_id, login_type, result ) except LoginError as e: - if login_type == LoginType.EMAIL_IDENTITY: - # riot used to have a bug where it would request a new - # validation token (thus sending a new email) each time it - # got a 401 with a 'flows' field. - # (https://github.com/vector-im/vector-web/issues/2447). - # - # Grandfather in the old behaviour for now to avoid - # breaking old riot deployments. - raise - # this step failed. Merge the error dict into the response # so that the client can have another go. errordict = e.error_dict() diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 048523ec94..bd35d1fb87 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -100,11 +100,7 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s%s?%s" % ( - self._cas_service_url, - "/_matrix/client/r0/login/cas/ticket", - urllib.parse.urlencode(args), - ) + return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index debb1b4f29..0863154f7a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: + async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ Retrieve the given user's devices @@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device(self, user_id: str, device_id: str) -> JsonDict: """ Retrieve the given device Args: @@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] + device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -946,8 +946,8 @@ class DeviceListUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[Dict[str, Any]], - self_signing_key: Optional[Dict[str, Any]], + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], ) -> List[str]: """Process the given new master and self-signing key for the given remote user. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 929752150d..8f3a6b35a4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( + JsonDict, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class E2eKeysHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -78,7 +82,9 @@ class E2eKeysHandler: ) @trace - async def query_devices(self, query_body, timeout, from_user_id): + async def query_devices( + self, query_body: JsonDict, timeout: int, from_user_id: str + ) -> JsonDict: """ Handle a device key query from a client { @@ -98,12 +104,14 @@ class E2eKeysHandler: } Args: - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] # separate users by domain. # make a map from domain to user_id to device_ids @@ -121,7 +129,8 @@ class E2eKeysHandler: set_tag("remote_key_query", remote_queries) # First get local devices. - failures = {} + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -135,9 +144,10 @@ class E2eKeysHandler: ) # Now attempt to get any remote devices from our local cache. - remote_queries_not_in_cache = {} + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] if remote_queries: - query_list = [] + query_list = [] # type: List[Tuple[str, Optional[str]]] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) @@ -284,15 +294,15 @@ class E2eKeysHandler: return ret async def get_cross_signing_keys_from_cache( - self, query, from_user_id + self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: - query (Iterable[string]) an iterable of user IDs. A dict whose keys + query: an iterable of user IDs. A dict whose keys are user IDs satisfies this, so the query format used for query_devices can be used here. - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. @@ -315,14 +325,12 @@ class E2eKeysHandler: if "self_signing" in user_info: self_signing_keys[user_id] = user_info["self_signing"] - if ( - from_user_id in keys - and keys[from_user_id] is not None - and "user_signing" in keys[from_user_id] - ): - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id: + from_user_key = keys.get(from_user_id) + if from_user_key and "user_signing" in from_user_key: + user_signing_keys[from_user_id] = from_user_key["user_signing"] return { "master_keys": master_keys, @@ -344,9 +352,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] + local_query = [] # type: List[Tuple[str, Optional[str]]] - result_dict = {} + result_dict = {} # type: Dict[str, Dict[str, dict]] for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -380,10 +388,14 @@ class E2eKeysHandler: log_kv(results) return result_dict - async def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys( + self, query_body: Dict[str, Dict[str, Optional[List[str]]]] + ) -> JsonDict: """ Handle a device key query from a federated server """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Optional[List[str]]] res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -397,31 +409,34 @@ class E2eKeysHandler: return ret @trace - async def claim_one_time_keys(self, query, timeout): - local_query = [] - remote_queries = {} + async def claim_one_time_keys( + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + ) -> JsonDict: + local_query = [] # type: List[Tuple[str, str, str]] + remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] - for user_id, device_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in device_keys.items(): + for device_id, algorithm in one_time_keys.items(): local_query.append((user_id, device_id, algorithm)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_keys + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) results = await self.store.claim_e2e_one_time_keys(local_query) - json_result = {} - failures = {} + # A map of user ID -> device ID -> key ID -> key. + json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + failures = {} # type: Dict[str, JsonDict] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_bytes) + key_id: json_decoder.decode(json_str) } @trace @@ -468,7 +483,9 @@ class E2eKeysHandler: return {"one_time_keys": json_result, "failures": failures} @tag_args - async def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: time_now = self.clock.time_msec() @@ -543,8 +560,8 @@ class E2eKeysHandler: return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( - self, user_id, device_id, time_now, one_time_keys - ): + self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict + ) -> None: logger.info( "Adding one_time_keys %r for device %r for user %r at %d", one_time_keys.keys(), @@ -585,12 +602,14 @@ class E2eKeysHandler: log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - async def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user( + self, user_id: str, keys: JsonDict + ) -> JsonDict: """Upload signing keys for cross-signing Args: - user_id (string): the user uploading the keys - keys (dict[string, dict]): the signing keys + user_id: the user uploading the keys + keys: the signing keys """ # if a master key is uploaded, then check it. Otherwise, load the @@ -667,16 +686,17 @@ class E2eKeysHandler: return {} - async def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys( + self, user_id: str, signatures: JsonDict + ) -> JsonDict: """Upload device signatures for cross-signing Args: - user_id (string): the user uploading the signatures - signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys. This is the submission from the user; an - exception will be raised if it is malformed. + user_id: the user uploading the signatures + signatures: map of users to devices to signed keys. This is the submission + from the user; an exception will be raised if it is malformed. Returns: - dict: response to be sent back to the client. The response will have + The response to be sent back to the client. The response will have a "failures" key, which will be a dict mapping users to devices to errors for the signatures that failed. Raises: @@ -719,7 +739,9 @@ class E2eKeysHandler: return {"failures": failures} - async def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures( + self, user_id: str, signatures: JsonDict + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -731,15 +753,14 @@ class E2eKeysHandler: signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure - reasons + A tuple of a list of signatures to store, and a map of users to + devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -834,19 +855,24 @@ class E2eKeysHandler: return signature_list, failures def _check_master_key_signature( - self, user_id, master_key_id, signed_master_key, stored_master_key, devices - ): + self, + user_id: str, + master_key_id: str, + signed_master_key: JsonDict, + stored_master_key: JsonDict, + devices: Dict[str, Dict[str, JsonDict]], + ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user whose master key is being checked - master_key_id (string): the ID of the user's master key - signed_master_key (dict): the user's signed master key that was uploaded - stored_master_key (dict): our previously-stored copy of the user's master key - devices (iterable(dict)): the user's devices + user_id: the user whose master key is being checked + master_key_id: the ID of the user's master key + signed_master_key: the user's signed master key that was uploaded + stored_master_key: our previously-stored copy of the user's master key + devices: the user's devices Returns: - list[SignatureListItem]: a list of signatures to store + A list of signatures to store Raises: SynapseError: if a signature is invalid @@ -877,25 +903,26 @@ class E2eKeysHandler: return master_key_signature_list - async def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures( + self, user_id: str, signatures: Dict[str, dict] + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id: the user uploading the keys + signatures: map of users to devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure + A list of signatures to store, and a map of users to devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -983,7 +1010,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None - ): + ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -997,8 +1024,7 @@ class E2eKeysHandler: This affects what signatures are fetched. Returns: - dict, str, VerifyKey: the raw key data, the key ID, and the - signedjson verify key + The raw key data, the key ID, and the signedjson verify key Raises: NotFoundError: if the key is not found @@ -1135,16 +1161,18 @@ class E2eKeysHandler: return desired_key, desired_key_id, desired_verify_key -def _check_cross_signing_key(key, user_id, key_type, signing_key=None): +def _check_cross_signing_key( + key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None +) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. Args: - key (dict): the key data to verify - user_id (str): the user whose key is being checked - key_type (str): the type of key that the key should be - signing_key (VerifyKey): (optional) the signing key that the key should - be signed with. If omitted, signatures will not be checked. + key: the key data to verify + user_id: the user whose key is being checked + key_type: the type of key that the key should be + signing_key: the signing key that the key should be signed with. If + omitted, signatures will not be checked. """ if ( key.get("user_id") != user_id @@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) -def _check_device_signature(user_id, verify_key, signed_device, stored_device): +def _check_device_signature( + user_id: str, + verify_key: VerifyKey, + signed_device: JsonDict, + stored_device: JsonDict, +) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an exception if an error is detected. Args: - user_id (str): the user ID whose signature is being checked - verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the uploaded signed device data - stored_device (dict): our previously stored copy of the device + user_id: the user ID whose signature is being checked + verify_key: the key to verify the device with + signed_device: the uploaded signed device data + stored_device: our previously stored copy of the device Raises: SynapseError: if the signature was invalid or the sent device is not the @@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) -def _exception_to_failure(e): +def _exception_to_failure(e: Exception) -> JsonDict: if isinstance(e, SynapseError): return {"status": e.code, "errcode": e.errcode, "message": str(e)} @@ -1218,7 +1251,7 @@ def _exception_to_failure(e): return {"status": 503, "message": str(e)} -def _one_time_keys_match(old_key_json, new_key): +def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly @@ -1239,16 +1272,16 @@ class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ - signing_key_id = attr.ib() - target_user_id = attr.ib() - target_device_id = attr.ib() - signature = attr.ib() + signing_key_id = attr.ib(type=str) + target_user_id = attr.ib(type=str) + target_device_id = attr.ib(type=str) + signature = attr.ib(type=JsonDict) class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs, e2e_keys_handler): + def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater: iterable=True, ) - async def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. Args: - origin (string): the server that sent the EDU - edu_content (dict): the contents of the EDU + origin: the server that sent the EDU + edu_content: the contents of the EDU """ user_id = edu_content.pop("user_id") @@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater: await self._handle_signing_key_updates(user_id) - async def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id: str) -> None: """Actually handle pending updates. Args: - user_id (string): the user whose updates we are processing + user_id: the user whose updates we are processing """ device_handler = self.e2e_keys_handler.device_handler @@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] + device_ids = [] # type: List[str] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f01b090772..622cae23be 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import ( Codes, @@ -24,8 +25,12 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class E2eRoomKeysHandler: The actual payload of the encrypted keys is completely opaque to the handler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # Used to lock whenever a client is uploading key data. This prevents collisions @@ -48,21 +53,27 @@ class E2eRoomKeysHandler: self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - async def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> List[JsonDict]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. Args: - user_id(str): the user whose keys we're getting - version(str): the version ID of the backup we're getting keys from - room_id(string): room ID to get keys for, for None to get keys for all rooms - session_id(string): session ID to get keys for, for None to get keys for all + user_id: the user whose keys we're getting + version: the version ID of the backup we're getting keys from + room_id: room ID to get keys for, for None to get keys for all rooms + session_id: session ID to get keys for, for None to get keys for all sessions Raises: NotFoundError: if the backup version does not exist Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -86,17 +97,23 @@ class E2eRoomKeysHandler: return results @trace - async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. Args: - user_id(str): the user whose backup we're deleting - version(str): the version ID of the backup we're deleting - room_id(string): room ID to delete keys for, for None to delete keys for all + user_id: the user whose backup we're deleting + version: the version ID of the backup we're deleting + room_id: room ID to delete keys for, for None to delete keys for all rooms - session_id(string): session ID to delete keys for, for None to delete keys + session_id: session ID to delete keys for, for None to delete keys for all sessions Raises: NotFoundError: if the backup version does not exist @@ -128,15 +145,17 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @trace - async def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys( + self, user_id: str, version: str, room_keys: JsonDict + ) -> JsonDict: """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_keys(dict): a nested dict describing the room_keys we're setting: + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_keys: a nested dict describing the room_keys we're setting: { "rooms": { @@ -254,14 +273,16 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @staticmethod - def _should_replace_room_key(current_room_key, room_key): + def _should_replace_room_key( + current_room_key: Optional[JsonDict], room_key: JsonDict + ) -> bool: """ Determine whether to replace a given current_room_key (if any) with a newly uploaded room_key backup Args: - current_room_key (dict): Optional, the current room_key dict if any - room_key (dict): The new room_key dict which may or may not be fit to + current_room_key: Optional, the current room_key dict if any + room_key : The new room_key dict which may or may not be fit to replace the current_room_key Returns: @@ -286,14 +307,14 @@ class E2eRoomKeysHandler: return True @trace - async def create_version(self, user_id, version_info): + async def create_version(self, user_id: str, version_info: JsonDict) -> str: """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. Args: - user_id(str): the user whose backup version we're creating - version_info(dict): metadata about the new version being created + user_id: the user whose backup version we're creating + version_info: metadata about the new version being created { "algorithm": "m.megolm_backup.v1", @@ -301,7 +322,7 @@ class E2eRoomKeysHandler: } Returns: - A deferred of a string that gives the new version number. + The new version number. """ # TODO: Validate the JSON to make sure it has the right keys. @@ -313,17 +334,19 @@ class E2eRoomKeysHandler: ) return new_version - async def get_version_info(self, user_id, version=None): + async def get_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're querying - version(str): Optional; if None gives the most recent version + user_id: the user whose current backup version we're querying + version: Optional; if None gives the most recent version otherwise a historical one. Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of a info dict that gives the info about the new version. + A info dict that gives the info about the new version. { "version": "1234", @@ -346,7 +369,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id, version=None): + async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: @@ -366,17 +389,19 @@ class E2eRoomKeysHandler: raise @trace - async def update_version(self, user_id, version, version_info): + async def update_version( + self, user_id: str, version: str, version_info: JsonDict + ) -> JsonDict: """Update the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're updating - version(str): the backup version we're updating - version_info(dict): the new information about the backup + user_id: the user whose current backup version we're updating + version: the backup version we're updating + version_info: the new information about the backup Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of an empty dict. + An empty dict. """ if "version" not in version_info: version_info["version"] = version diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd8de8696d..dbdfd56ff5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + member_handler.ratelimit_invite(event.room_id, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). @@ -2093,6 +2097,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index df29edeb83..71f11ef94a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -15,9 +15,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Dict, Iterable, List, Set from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, get_domain_from_id +from synapse.types import GroupID, JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ def _create_rerouter(func_name): class GroupsLocalWorkerHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.room_list_handler = hs.get_room_list_handler() @@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler: get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") - async def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the group summary for a group. If the group is remote we check that the users have valid attestations. @@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler: return res - async def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users in a group """ if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_users_in_group( + return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) - return res group_server_name = get_domain_from_id(group_id) @@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler: return res - async def get_joined_groups(self, user_id): + async def get_joined_groups(self, user_id: str) -> JsonDict: group_ids = await self.store.get_joined_groups(user_id) return {"groups": group_ids} - async def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: if self.hs.is_mine_id(user_id): result = await self.store.get_publicised_groups_for_user(user_id) @@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler: # TODO: Verify attestations return {"groups": result} - async def bulk_get_publicised_groups(self, user_ids, proxy=True): - destinations = {} + async def bulk_get_publicised_groups( + self, user_ids: Iterable[str], proxy: bool = True + ) -> JsonDict: + destinations = {} # type: Dict[str, Set[str]] local_users = set() for user_id in user_ids: @@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] + failed_results = [] # type: List[str] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( @@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # Ensure attestations get renewed @@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - async def create_group(self, group_id, user_id, content): + async def create_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Create a group """ @@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = None remote_attestation = None else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - content["user_profile"] = await self.profile_handler.get_profile(user_id) - - try: - res = await self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) + raise SynapseError(400, "Unable to create remote groups") is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def join_group(self, group_id, user_id, content): + async def join_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Request to join a group """ if self.is_mine_id(group_id): @@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def accept_invite(self, group_id, user_id, content): + async def accept_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Accept an invite to a group """ if self.is_mine_id(group_id): @@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def invite(self, group_id, user_id, requester_user_id, config): + async def invite( + self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict + ) -> JsonDict: """Invite a user to a group """ content = {"requester_user_id": requester_user_id, "config": config} @@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def on_invite(self, group_id, user_id, content): + async def on_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """One of our users were invited to a group """ # TODO: Support auto join and rejection @@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} async def remove_user_from_group( - self, group_id, user_id, requester_user_id, content - ): + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Remove a user from a group """ if user_id == requester_user_id: @@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def user_removed_from_group(self, group_id, user_id, content): + async def user_removed_from_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> None: """One of our users was removed/kicked from a group """ # TODO: Check if user in group diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..4f7137539b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dfeab09cd..e2a7d567fa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -432,6 +432,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -939,6 +941,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +982,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. + + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee27d99135..07b2187eb1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e001e418f9..d335da6f19 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: str, invitee_user_id: str): + """Ratelimit invites by room and by target user. + """ + self._invites_per_room_limiter.ratelimit(room_id) + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): block_invite = True if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), target_id, room_id ): logger.info("Blocking invite due to spam checker") block_invite = True diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 66f1bbcfc4..94062e79cb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -15,23 +15,28 @@ import itertools import logging -from typing import Iterable +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.storage.state import StateFilter +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() @@ -87,13 +92,15 @@ class SearchHandler(BaseHandler): return historical_room_ids - async def search(self, user, content, batch=None): + async def search( + self, user: UserID, content: JsonDict, batch: Optional[str] = None + ) -> JsonDict: """Performs a full text search for a user. Args: - user (UserID) - content (dict): Search parameters - batch (str): The next_batch parameter. Used for pagination. + user + content: Search parameters + batch: The next_batch parameter. Used for pagination. Returns: dict to be returned to the client with results of search @@ -186,7 +193,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] + historical_room_ids = [] # type: List[str] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -209,8 +216,10 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] - room_groups = {} # Holds result of grouping by room, if applicable - sender_group = {} # Holds result of grouping by sender, if applicable + # Holds result of grouping by room, if applicable + room_groups = {} # type: Dict[str, JsonDict] + # Holds result of grouping by sender, if applicable + sender_group = {} # type: Dict[str, JsonDict] # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -254,7 +263,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] + room_events = [] # type: List[EventBase] i = 0 pagination_token = batch_token @@ -418,13 +427,10 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = {e.room_id for e in allowed_events} - for room_id in rooms: + for room_id in {e.room_id for e in allowed_events}: state = await self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) - state_results.values() - # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong @@ -448,9 +454,9 @@ class SearchHandler(BaseHandler): if state_results: s = {} - for room_id, state in state_results.items(): + for room_id, state_events in state_results.items(): s[room_id] = await self._event_serializer.serialize_events( - state, time_now + state_events, time_now ) rooms_cat_res["state"] = s diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a5d67f828f..84af2dde7e 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -13,24 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - self._password_policy_handler = hs.get_password_policy_handler() async def set_password( self, @@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler): password_hash: str, logout_devices: bool, requester: Optional[Requester] = None, - ): + ) -> None: if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index fb4f70e8e2..b3f9875358 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -14,15 +14,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class StateDeltasHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change( + self, + prev_event_id: Optional[str], + event_id: Optional[str], + key_name: str, + public_value: str, + ) -> Optional[bool]: """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index dc62b21c06..d261d7cd4e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -12,13 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple + +from typing_extensions import Counter as CounterType from synapse.api.constants import EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +37,7 @@ class StatsHandler: Heavily derived from UserDirectoryHandler """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -44,7 +50,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None + self.pos = None # type: Optional[int] # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -56,7 +62,7 @@ class StatsHandler: # we start populating stats self.clock.call_later(0, self.notify_new_event) - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when there may be more deltas to process """ if not self.stats_enabled or self._is_processing: @@ -72,7 +78,7 @@ class StatsHandler: run_as_background_process("stats.notify_new_event", process) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_stats_positions() @@ -110,10 +116,10 @@ class StatsHandler: ) for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, {}).update(fields) + room_deltas.setdefault(room_id, Counter()).update(fields) for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, {}).update(fields) + user_deltas.setdefault(user_id, Counter()).update(fields) logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -131,19 +137,20 @@ class StatsHandler: self.pos = max_pos - async def _handle_deltas(self, deltas): + async def _handle_deltas( + self, deltas: Iterable[JsonDict] + ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]: """Called with the state deltas to process Returns: - tuple[dict[str, Counter], dict[str, counter]] Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ - room_to_stats_deltas = {} - user_to_stats_deltas = {} + room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - room_to_state_updates = {} + room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] for delta in deltas: typ = delta["type"] @@ -173,7 +180,7 @@ class StatsHandler: ) continue - event_content = {} + event_content = {} # type: JsonDict sender = None if event_id is not None: @@ -257,13 +264,13 @@ class StatsHandler: ) if has_changed_joinedness: - delta = +1 if membership == Membership.JOIN else -1 + membership_delta = +1 if membership == Membership.JOIN else -1 user_to_stats_deltas.setdefault(user_id, Counter())[ "joined_rooms" - ] += delta + ] += membership_delta - room_stats_delta["local_users_in_room"] += delta + room_stats_delta["local_users_in_room"] += membership_delta elif typ == EventTypes.Create: room_state["is_federatable"] = ( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e919a8f9ed..3f0dfc7a74 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,13 +15,13 @@ import logging import random from collections import namedtuple -from typing import TYPE_CHECKING, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -65,17 +65,17 @@ class FollowerTypingHandler: ) # map room IDs to serial numbers - self._room_serials = {} + self._room_serials = {} # type: Dict[str, int] # map room IDs to sets of users currently typing - self._room_typing = {} + self._room_typing = {} # type: Dict[str, Set[str]] - self._member_last_federation_poke = {} + self._member_last_federation_poke = {} # type: Dict[RoomMember, int] self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) - def _reset(self): + def _reset(self) -> None: """Reset the typing handler's data caches. """ # map room IDs to serial numbers @@ -86,7 +86,7 @@ class FollowerTypingHandler: self._member_last_federation_poke = {} self.wheel_timer = WheelTimer(bucket_size=5000) - def _handle_timeouts(self): + def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() @@ -96,7 +96,7 @@ class FollowerTypingHandler: for member in members: self._handle_timeout_for_member(now, member) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: if not self.is_typing(member): # Nothing to do if they're no longer typing return @@ -114,10 +114,10 @@ class FollowerTypingHandler: # each person typing. self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) - def is_typing(self, member): + def is_typing(self, member: RoomMember) -> bool: return member.user_id in self._room_typing.get(member.room_id, []) - async def _push_remote(self, member, typing): + async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: return @@ -148,7 +148,7 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: """Should be called whenever we receive updates for typing stream. """ @@ -178,7 +178,7 @@ class FollowerTypingHandler: async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] - ): + ) -> None: """Process a change in typing of a room from replication, sending EDUs for any local users. """ @@ -194,12 +194,12 @@ class FollowerTypingHandler: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), False) - def get_current_token(self): + def get_current_token(self) -> int: return self._latest_room_serial class TypingWriterHandler(FollowerTypingHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) assert hs.config.worker.writers.typing == hs.get_instance_name() @@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - self._member_typing_until = {} # clock time we expect to stop + # clock time we expect to stop + self._member_typing_until = {} # type: Dict[RoomMember, int] # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", self._latest_room_serial ) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: super()._handle_timeout_for_member(now, member) if not self.is_typing(member): @@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) return - async def started_typing(self, target_user, requester, room_id, timeout): + async def started_typing( + self, target_user: UserID, requester: Requester, room_id: str, timeout: int + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler): if was_present: # No point sending another notification - return None + return self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, requester, room_id): + async def stopped_typing( + self, target_user: UserID, requester: Requester, room_id: str + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) - def user_left_room(self, user, room_id): + def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) self._stopped_typing(member) - def _stopped_typing(self, member): + def _stopped_typing(self, member: RoomMember) -> None: if member.user_id not in self._room_typing.get(member.room_id, set()): # No point - return None + return self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) self._push_update(member=member, typing=False) - def _push_update(self, member, typing): + def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_as_background_process( @@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler): self._push_update_local(member=member, typing=typing) - async def _recv_edu(self, origin, content): + async def _recv_edu(self, origin: str, content: JsonDict) -> None: room_id = content["room_id"] user_id = content["user_id"] @@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler): self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) self._push_update_local(member=member, typing=content["typing"]) - def _push_update_local(self, member, typing): + def _push_update_local(self, member: RoomMember, typing: bool) -> None: room_set = self._room_typing.setdefault(member.room_id, set()) if typing: room_set.add(member.user_id) @@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler): changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id - ) + ) # type: Optional[Iterable[str]] if changed_rooms is None: changed_rooms = self._room_serials @@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler): def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: # The writing process should never get updates from replication. raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: @@ -427,7 +432,7 @@ class TypingNotificationEventSource: # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id): + def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", @@ -462,7 +467,9 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: Iterable[str], **kwargs + ) -> Tuple[List[JsonDict], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -478,5 +485,5 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - def get_current_key(self): + def get_current_key(self) -> int: return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d4651c8348..8aedf5072e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler): if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() - # If still None then the initial background update hasn't happened yet - if self.pos is None: - return None - # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): @@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler): if change: # The user joined event = await self.store.get_event(event_id, allow_none=True) + # It isn't expected for this event to not exist, but we + # don't want the entire background process to break. + if event is None: + continue + profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ab586c318c..0538350f38 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -791,7 +791,7 @@ def tag_args(func): @wraps(func) def _tag_args_inner(*args, **kwargs): - argspec = inspect.getargspec(func) + argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): set_tag("ARG_" + arg, args[i]) set_tag("args", args[len(argspec.args) :]) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 4d875dcb91..8a6dcff30d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -267,9 +267,21 @@ class Mailer: fallback_to_members=True, ) - summary_text = await self.make_summary_text( - notifs_by_room, state_by_room, notif_events, user_id, reason - ) + if len(notifs_by_room) == 1: + # Only one room has new stuff + room_id = list(notifs_by_room.keys())[0] + + summary_text = await self.make_summary_text_single_room( + room_id, + notifs_by_room[room_id], + state_by_room[room_id], + notif_events, + user_id, + ) + else: + summary_text = await self.make_summary_text( + notifs_by_room, state_by_room, notif_events, reason + ) template_vars = { "user_display_name": user_display_name, @@ -492,139 +504,178 @@ class Mailer: if "url" in event.content: messagevars["image_url"] = event.content["url"] - async def make_summary_text( + async def make_summary_text_single_room( self, - notifs_by_room: Dict[str, List[Dict[str, Any]]], - room_state_ids: Dict[str, StateMap[str]], + room_id: str, + notifs: List[Dict[str, Any]], + room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], user_id: str, - reason: Dict[str, Any], - ): - if len(notifs_by_room) == 1: - # Only one room has new stuff - room_id = list(notifs_by_room.keys())[0] + ) -> str: + """ + Make a summary text for the email when only a single room has notifications. - # If the room has some kind of name, use it, but we don't - # want the generated-from-names one here otherwise we'll - # end up with, "new message from Bob in the Bob room" - room_name = await calculate_room_name( - self.store, room_state_ids[room_id], user_id, fallback_to_members=False - ) + Args: + room_id: The ID of the room. + notifs: The notifications for this room. + room_state_ids: The state map for the room. + notif_events: A map of event ID -> notification event. + user_id: The user receiving the notification. + + Returns: + The summary text. + """ + # If the room has some kind of name, use it, but we don't + # want the generated-from-names one here otherwise we'll + # end up with, "new message from Bob in the Bob room" + room_name = await calculate_room_name( + self.store, room_state_ids, user_id, fallback_to_members=False + ) - # See if one of the notifs is an invite event for the user - invite_event = None - for n in notifs_by_room[room_id]: - ev = notif_events[n["event_id"]] - if ev.type == EventTypes.Member and ev.state_key == user_id: - if ev.content.get("membership") == Membership.INVITE: - invite_event = ev - break - - if invite_event: - inviter_member_event_id = room_state_ids[room_id].get( - ("m.room.member", invite_event.sender) - ) - inviter_name = invite_event.sender - if inviter_member_event_id: - inviter_member_event = await self.store.get_event( - inviter_member_event_id, allow_none=True - ) - if inviter_member_event: - inviter_name = name_from_member_event(inviter_member_event) - - if room_name is None: - return self.email_subjects.invite_from_person % { - "person": inviter_name, - "app": self.app_name, - } - else: - return self.email_subjects.invite_from_person_to_room % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - } + # See if one of the notifs is an invite event for the user + invite_event = None + for n in notifs: + ev = notif_events[n["event_id"]] + if ev.type == EventTypes.Member and ev.state_key == user_id: + if ev.content.get("membership") == Membership.INVITE: + invite_event = ev + break - sender_name = None - if len(notifs_by_room[room_id]) == 1: - # There is just the one notification, so give some detail - event = notif_events[notifs_by_room[room_id][0]["event_id"]] - if ("m.room.member", event.sender) in room_state_ids[room_id]: - state_event_id = room_state_ids[room_id][ - ("m.room.member", event.sender) - ] - state_event = await self.store.get_event(state_event_id) - sender_name = name_from_member_event(state_event) - - if sender_name is not None and room_name is not None: - return self.email_subjects.message_from_person_in_room % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - } - elif sender_name is not None: - return self.email_subjects.message_from_person % { - "person": sender_name, - "app": self.app_name, - } - else: - # There's more than one notification for this room, so just - # say there are several - if room_name is not None: - return self.email_subjects.messages_in_room % { - "room": room_name, - "app": self.app_name, - } - else: - # If the room doesn't have a name, say who the messages - # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list( - { - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - } - ) - - member_events = await self.store.get_events( - [ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ] - ) - - return self.email_subjects.messages_from_person % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - } - else: - # Stuff's happened in multiple different rooms + if invite_event: + inviter_member_event_id = room_state_ids.get( + ("m.room.member", invite_event.sender) + ) + inviter_name = invite_event.sender + if inviter_member_event_id: + inviter_member_event = await self.store.get_event( + inviter_member_event_id, allow_none=True + ) + if inviter_member_event: + inviter_name = name_from_member_event(inviter_member_event) - # ...but we still refer to the 'reason' room which triggered the mail - if reason["room_name"] is not None: - return self.email_subjects.messages_in_room_and_others % { - "room": reason["room_name"], + if room_name is None: + return self.email_subjects.invite_from_person % { + "person": inviter_name, "app": self.app_name, } - else: - # If the reason room doesn't have a name, say who the messages - # are from explicitly to avoid, "messages in the Bob room" - room_id = reason["room_id"] - - sender_ids = list( - { - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - } - ) - member_events = await self.store.get_events( - [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] - ) + return self.email_subjects.invite_from_person_to_room % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + } + + if len(notifs) == 1: + # There is just the one notification, so give some detail + sender_name = None + event = notif_events[notifs[0]["event_id"]] + if ("m.room.member", event.sender) in room_state_ids: + state_event_id = room_state_ids[("m.room.member", event.sender)] + state_event = await self.store.get_event(state_event_id) + sender_name = name_from_member_event(state_event) + + if sender_name is not None and room_name is not None: + return self.email_subjects.message_from_person_in_room % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + } + elif sender_name is not None: + return self.email_subjects.message_from_person % { + "person": sender_name, + "app": self.app_name, + } - return self.email_subjects.messages_from_person_and_others % { - "person": descriptor_from_member_events(member_events.values()), + # The sender is unknown, just use the room name (or ID). + return self.email_subjects.messages_in_room % { + "room": room_name or room_id, + "app": self.app_name, + } + else: + # There's more than one notification for this room, so just + # say there are several + if room_name is not None: + return self.email_subjects.messages_in_room % { + "room": room_name, "app": self.app_name, } + return await self.make_summary_text_from_member_events( + room_id, notifs, room_state_ids, notif_events + ) + + async def make_summary_text( + self, + notifs_by_room: Dict[str, List[Dict[str, Any]]], + room_state_ids: Dict[str, StateMap[str]], + notif_events: Dict[str, EventBase], + reason: Dict[str, Any], + ) -> str: + """ + Make a summary text for the email when multiple rooms have notifications. + + Args: + notifs_by_room: A map of room ID to the notifications for that room. + room_state_ids: A map of room ID to the state map for that room. + notif_events: A map of event ID -> notification event. + reason: The reason this notification is being sent. + + Returns: + The summary text. + """ + # Stuff's happened in multiple different rooms + # ...but we still refer to the 'reason' room which triggered the mail + if reason["room_name"] is not None: + return self.email_subjects.messages_in_room_and_others % { + "room": reason["room_name"], + "app": self.app_name, + } + + room_id = reason["room_id"] + return await self.make_summary_text_from_member_events( + room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events + ) + + async def make_summary_text_from_member_events( + self, + room_id: str, + notifs: List[Dict[str, Any]], + room_state_ids: StateMap[str], + notif_events: Dict[str, EventBase], + ) -> str: + """ + Make a summary text for the email when only a single room has notifications. + + Args: + room_id: The ID of the room. + notifs: The notifications for this room. + room_state_ids: The state map for the room. + notif_events: A map of event ID -> notification event. + + Returns: + The summary text. + """ + # If the room doesn't have a name, say who the messages + # are from explicitly to avoid, "messages in the Bob room" + sender_ids = {notif_events[n["event_id"]].sender for n in notifs} + + member_events = await self.store.get_events( + [room_state_ids[("m.room.member", s)] for s in sender_ids] + ) + + # There was a single sender. + if len(sender_ids) == 1: + return self.email_subjects.messages_from_person % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } + + # There was more than one sender, use the first one and a tweaked template. + return self.email_subjects.messages_from_person_and_others % { + "person": descriptor_from_member_events(list(member_events.values())[:1]), + "app": self.app_name, + } + def make_room_link(self, room_id: str) -> str: if self.hs.config.email_riot_base_url: base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) @@ -668,6 +719,15 @@ class Mailer: def safe_markup(raw_html: str) -> jinja2.Markup: + """ + Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs. + + Args + raw_html: Unsafe HTML. + + Returns: + A Markup object ready to safely use in a Jinja template. + """ return jinja2.Markup( bleach.linkify( bleach.clean( @@ -684,8 +744,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup: def safe_text(raw_text: str) -> jinja2.Markup: """ - Process text: treat it as HTML but escape any tags (ie. just escape the - HTML) then linkify it. + Sanitise text (escape any HTML tags), and then linkify any bare URLs. + + Args + raw_text: Unsafe text which might include HTML markup. + + Returns: + A Markup object ready to safely use in a Jinja template. """ return jinja2.Markup( bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False)) diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 7e50341d74..04c2c1482c 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -17,7 +17,7 @@ import logging import re from typing import TYPE_CHECKING, Dict, Iterable, Optional -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.types import StateMap @@ -63,7 +63,7 @@ async def calculate_room_name( m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) - if m_room_name and m_room_name.content and m_room_name.content["name"]: + if m_room_name and m_room_name.content and m_room_name.content.get("name"): return m_room_name.content["name"] # does it have a canonical alias? @@ -74,15 +74,11 @@ async def calculate_room_name( if ( canon_alias and canon_alias.content - and canon_alias.content["alias"] + and canon_alias.content.get("alias") and _looks_like_an_alias(canon_alias.content["alias"]) ): return canon_alias.content["alias"] - # at this point we're going to need to search the state by all state keys - # for an event type, so rearrange the data structure - room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) - if not fallback_to_members: return None @@ -94,7 +90,7 @@ async def calculate_room_name( if ( my_member_event is not None - and my_member_event.content["membership"] == "invite" + and my_member_event.content.get("membership") == Membership.INVITE ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: inviter_member_event = await store.get_event( @@ -111,6 +107,10 @@ async def calculate_room_name( else: return "Room Invite" + # at this point we're going to need to search the state by all state keys + # for an event type, so rearrange the data structure + room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) + # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: @@ -120,8 +120,8 @@ async def calculate_room_name( all_members = [ ev for ev in member_events.values() - if ev.content["membership"] == "join" - or ev.content["membership"] == "invite" + if ev.content.get("membership") == Membership.JOIN + or ev.content.get("membership") == Membership.INVITE ] # Sort the member events oldest-first so the we name people in the # order the joined (it should at least be deterministic rather than @@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str: def name_from_member_event(member_event: EventBase) -> str: - if ( - member_event.content - and "displayname" in member_event.content - and member_event.content["displayname"] - ): + if member_event.content and member_event.content.get("displayname"): return member_event.content["displayname"] return member_event.state_key diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py new file mode 100644 index 0000000000..34fa3ff5b3 --- /dev/null +++ b/synapse/replication/tcp/external_cache.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from prometheus_client import Counter + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import json_decoder, json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +set_counter = Counter( + "synapse_external_cache_set", + "Number of times we set a cache", + labelnames=["cache_name"], +) + +get_counter = Counter( + "synapse_external_cache_get", + "Number of times we get a cache", + labelnames=["cache_name", "hit"], +) + + +logger = logging.getLogger(__name__) + + +class ExternalCache: + """A cache backed by an external Redis. Does nothing if no Redis is + configured. + """ + + def __init__(self, hs: "HomeServer"): + self._redis_connection = hs.get_outbound_redis_connection() + + def _get_redis_key(self, cache_name: str, key: str) -> str: + return "cache_v1:%s:%s" % (cache_name, key) + + def is_enabled(self) -> bool: + """Whether the external cache is used or not. + + It's safe to use the cache when this returns false, the methods will + just no-op, but the function is useful to avoid doing unnecessary work. + """ + return self._redis_connection is not None + + async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: + """Add the key/value to the named cache, with the expiry time given. + """ + + if self._redis_connection is None: + return + + set_counter.labels(cache_name).inc() + + # txredisapi requires the value to be string, bytes or numbers, so we + # encode stuff in JSON. + encoded_value = json_encoder.encode(value) + + logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) + + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + ) + ) + + async def get(self, cache_name: str, key: str) -> Optional[Any]: + """Look up a key/value in the named cache. + """ + + if self._redis_connection is None: + return None + + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) + + logger.debug("Got cache result %s %s: %r", cache_name, key, result) + + get_counter.labels(cache_name, result is not None).inc() + + if not result: + return None + + # For some reason the integers get magically converted back to integers + if isinstance(result, int): + return result + + return json_decoder.decode(result) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 317796d5e0..8ea8dcd587 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,6 +15,7 @@ # limitations under the License. import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Dict, @@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import ( TypingStream, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -88,7 +92,7 @@ class ReplicationCommandHandler: back out to connections. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() self._store = hs.get_datastore() @@ -282,13 +286,6 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, - lazyConnection, - ) - - logger.info( - "Connecting to redis (host=%r port=%r)", - hs.config.redis_host, - hs.config.redis_port, ) # First let's ensure that we have a ReplicationStreamer started. @@ -299,13 +296,7 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = lazyConnection( - reactor=hs.get_reactor(), - host=hs.config.redis_host, - port=hs.config.redis_port, - password=hs.config.redis.redis_password, - reconnect=True, - ) + outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. self._factory = RedisDirectTcpReplicationClientFactory( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index bc6ba709a7..fdd087683b 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -15,7 +15,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Type, cast import txredisapi @@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, run_as_background_process, + wrap_as_background_process, ) from synapse.replication.tcp.commands import ( Command, @@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): immediately after initialisation. Attributes: - handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to and publish from - (not anything to do with Synapse replication streams). - outbound_redis_connection: The connection to redis to use to send + synapse_handler: The command handler to handle incoming commands. + synapse_stream_name: The *redis* stream name to subscribe to and publish + from (not anything to do with Synapse replication streams). + synapse_outbound_redis_connection: The connection to redis to use to send commands. """ - handler = None # type: ReplicationCommandHandler - stream_name = None # type: str - outbound_redis_connection = None # type: txredisapi.RedisProtocol + synapse_handler = None # type: ReplicationCommandHandler + synapse_stream_name = None # type: str + synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # it's important to make sure that we only send the REPLICATE command once we # have successfully subscribed to the stream - otherwise we might miss the # POSITION response sent back by the other end. - logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) - await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name) + await make_deferred_yieldable(self.subscribe(self.synapse_stream_name)) logger.info( "Successfully subscribed to redis stream, sending REPLICATE command" ) - self.handler.new_connection(self) + self.synapse_handler.new_connection(self) await self._async_send_command(ReplicateCommand()) logger.info("REPLICATE successfully sent") # We send out our positions when there is a new connection in case the # other side missed updates. We do this for Redis connections as the # otherside won't know we've connected and so won't issue a REPLICATE. - self.handler.send_positions_to_connection(self) + self.synapse_handler.send_positions_to_connection(self) def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. @@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): cmd: received command """ - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None) if not cmd_func: logger.warning("Unhandled command: %r", cmd) return @@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) - self.handler.lost_connection(self) + self.synapse_handler.lost_connection(self) # mark the logging context as finished self._logging_context.__exit__(None, None, None) @@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() await make_deferred_yieldable( - self.outbound_redis_connection.publish(self.stream_name, encoded_string) + self.synapse_outbound_redis_connection.publish( + self.synapse_stream_name, encoded_string + ) + ) + + +class SynapseRedisFactory(txredisapi.RedisFactory): + """A subclass of RedisFactory that periodically sends pings to ensure that + we detect dead connections. + """ + + def __init__( + self, + hs: "HomeServer", + uuid: str, + dbid: Optional[int], + poolsize: int, + isLazy: bool = False, + handler: Type = txredisapi.ConnectionHandler, + charset: str = "utf-8", + password: Optional[str] = None, + replyTimeout: int = 30, + convertNumbers: Optional[int] = True, + ): + super().__init__( + uuid=uuid, + dbid=dbid, + poolsize=poolsize, + isLazy=isLazy, + handler=handler, + charset=charset, + password=password, + replyTimeout=replyTimeout, + convertNumbers=convertNumbers, ) + hs.get_clock().looping_call(self._send_ping, 30 * 1000) + + @wrap_as_background_process("redis_ping") + async def _send_ping(self): + for connection in self.pool: + try: + await make_deferred_yieldable(connection.ping()) + except Exception: + logger.warning("Failed to send ping to a redis connection") -class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + +class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """This is a reconnecting factory that connects to redis and immediately subscribes to a stream. @@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol ): - super().__init__() - - # This sets the password on the RedisFactory base class (as - # SubscriberFactory constructor doesn't pass it through). - self.password = hs.config.redis.redis_password + super().__init__( + hs, + uuid="subscriber", + dbid=None, + poolsize=1, + replyTimeout=30, + password=hs.config.redis.redis_password, + ) - self.handler = hs.get_tcp_replication() - self.stream_name = hs.hostname + self.synapse_handler = hs.get_tcp_replication() + self.synapse_stream_name = hs.hostname - self.outbound_redis_connection = outbound_redis_connection + self.synapse_outbound_redis_connection = outbound_redis_connection def buildProtocol(self, addr): - p = super().buildProtocol(addr) # type: RedisSubscriber + p = super().buildProtocol(addr) + p = cast(RedisSubscriber, p) # We do this here rather than add to the constructor of `RedisSubcriber` # as to do so would involve overriding `buildProtocol` entirely, however # the base method does some other things than just instantiating the # protocol. - p.handler = self.handler - p.outbound_redis_connection = self.outbound_redis_connection - p.stream_name = self.stream_name - p.password = self.password + p.synapse_handler = self.synapse_handler + p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection + p.synapse_stream_name = self.synapse_stream_name return p def lazyConnection( - reactor, + hs: "HomeServer", host: str = "localhost", port: int = 6379, dbid: Optional[int] = None, reconnect: bool = True, - charset: str = "utf-8", password: Optional[str] = None, - connectTimeout: Optional[int] = None, - replyTimeout: Optional[int] = None, - convertNumbers: bool = True, + replyTimeout: int = 30, ) -> txredisapi.RedisProtocol: - """Equivalent to `txredisapi.lazyConnection`, except allows specifying a - reactor. + """Creates a connection to Redis that is lazily set up and reconnects if the + connections is lost. """ - isLazy = True - poolsize = 1 - uuid = "%s:%d" % (host, port) - factory = txredisapi.RedisFactory( - uuid, - dbid, - poolsize, - isLazy, - txredisapi.ConnectionHandler, - charset, - password, - replyTimeout, - convertNumbers, + factory = SynapseRedisFactory( + hs, + uuid=uuid, + dbid=dbid, + poolsize=1, + isLazy=True, + handler=txredisapi.ConnectionHandler, + password=password, + replyTimeout=replyTimeout, ) factory.continueTrying = reconnect - for x in range(poolsize): - reactor.connectTCP(host, port, factory, connectTimeout) + + reactor = hs.get_reactor() + reactor.connectTCP(host, port, factory, 30) return factory.handler diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html index a75c73a142..c9bd4bef20 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -12,7 +12,7 @@ <header> <h1>That doesn't look right</h1> <p> - <strong>We were unable to validate your {{ server_name | e }} account</strong> + <strong>We were unable to validate your {{ server_name }} account</strong> via single sign‑on (SSO), because the SSO Identity Provider returned different details than when you logged in. </p> diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index d572ab87f7..790470fb59 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -12,7 +12,7 @@ <header> <h1>Confirm it's you to continue</h1> <p> - A client is trying to {{ description | e }}. To confirm this action + A client is trying to {{ description }}. To confirm this action re-authorize your account with single sign-on. </p> <p><strong> @@ -20,8 +20,8 @@ </strong></p> </header> <main> - <a href="{{ redirect_url | e }}" class="primary-button"/> - Continue with {{ idp.idp_name | e }} + <a href="{{ redirect_url }}" class="primary-button"/> + Continue with {{ idp.idp_name }} </a> </main> </body> diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index 69b93d65c1..b223ca0f56 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -22,7 +22,7 @@ <header> <h1>There was an error</h1> <p> - <strong id="errormsg">{{ error_description | e }}</strong> + <strong id="errormsg">{{ error_description }}</strong> </p> <p> If you are seeing this page after clicking a link sent to you via email, @@ -35,7 +35,7 @@ </p> <div id="error_code"> <p><strong>Error code</strong></p> - <p>{{ error | e }}</p> + <p>{{ error }}</p> </div> </header> diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index 5b38481012..62a640dad2 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -3,22 +3,22 @@ <head> <meta charset="UTF-8"> <link rel="stylesheet" href="/_matrix/static/client/login/style.css"> - <title>{{server_name | e}} Login</title> + <title>{{ server_name }} Login</title> </head> <body> <div id="container"> - <h1 id="title">{{server_name | e}} Login</h1> + <h1 id="title">{{ server_name }} Login</h1> <div class="login_flow"> <p>Choose one of the following identity providers:</p> <form> - <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}"> + <input type="hidden" name="redirectUrl" value="{{ redirect_url }}"> <ul class="radiobuttons"> {% for p in providers %} <li> - <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}"> - <label for="prov{{loop.index}}">{{p.idp_name | e}}</label> + <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}"> + <label for="prov{{ loop.index }}">{{ p.idp_name }}</label> {% if p.idp_icon %} - <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/> + <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/> {% endif %} </li> {% endfor %} diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index ce4f573848..d1328a6969 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -12,11 +12,11 @@ <header> {% if new_user %} <h1>Your account is now ready</h1> - <p>You've made your account on {{ server_name | e }}.</p> + <p>You've made your account on {{ server_name }}.</p> {% else %} <h1>Log in</h1> {% endif %} - <p>Continue to confirm you trust <strong>{{ display_url | e }}</strong>.</p> + <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p> </header> <main> {% if user_profile.avatar_url %} @@ -24,13 +24,13 @@ <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> <div class="profile-details"> {% if user_profile.display_name %} - <div class="display-name">{{ user_profile.display_name | e }}</div> + <div class="display-name">{{ user_profile.display_name }}</div> {% endif %} - <div class="user-id">{{ user_id | e }}</div> + <div class="user-id">{{ user_id }}</div> </div> </div> {% endif %} - <a href="{{ redirect_url | e }}" class="primary-button">Continue</a> + <a href="{{ redirect_url }}" class="primary-button">Continue</a> </main> </body> </html> diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6f7dc06503..57e0a10837 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd +# Copyright 2020, 2021 The Matrix.org Foundation C.I.C. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, + ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, MakeRoomAdminRestServlet, @@ -51,6 +54,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -230,6 +234,8 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) + ForwardExtremitiesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index ab7cc9102a..f14915d47e 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -431,7 +431,17 @@ class MakeRoomAdminRestServlet(RestServlet): if not admin_users: raise SynapseError(400, "No local admin user in room") - admin_user_id = admin_users[-1] + admin_user_id = None + + for admin_user in reversed(admin_users): + if room_state.get((EventTypes.Member, admin_user)): + admin_user_id = admin_user + break + + if not admin_user_id: + raise SynapseError( + 400, "No local admin user in room", + ) pl_content = power_levels.content else: @@ -499,3 +509,60 @@ class MakeRoomAdminRestServlet(RestServlet): ) return 200, {} + + +class ForwardExtremitiesRestServlet(RestServlet): + """Allows a server admin to get or clear forward extremities. + + Clearing does not require restarting the server. + + Clear forward extremities: + DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities + + Get forward_extremities: + GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.store = hs.get_datastore() + + async def resolve_room_id(self, room_identifier: str) -> str: + """Resolve to a room ID, if necessary.""" + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id + + async def on_DELETE(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + deleted_count = await self.store.delete_forward_extremities_for_room(room_id) + return 200, {"deleted": deleted_count} + + async def on_GET(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + extremities = await self.store.get_forward_extremities_for_room(room_id) + return 200, {"count": len(extremities), "results": extremities} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f39e3d6d5c..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret @@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned users receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 65e68d641b..a84a2fb385 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b093183e79..10e1891174 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 31a41e4a27..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -300,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -312,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -321,6 +323,7 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a632099167..bf3be653aa 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Check whether the URL should be downloaded as oEmbed content instead. - Params: + Args: url: The URL to check. Returns: @@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Request content from an oEmbed endpoint. - Params: + Args: endpoint: The oEmbed API endpoint. url: The URL to pass to the API. @@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource): def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: + """ + Calculate metadata for an HTML document. + + This uses lxml to parse the HTML document into the OG response. If errors + occur during processing of the document, an empty response is returned. + + Args: + body: The HTML document, as bytes. + media_url: The URI used to download the body. + request_encoding: The character encoding of the body, as a string. + + Returns: + The OG response as a dictionary. + """ # If there's no body, nothing useful is going to be found. if not body: return {} from lxml import etree + # Create an HTML parser. If this fails, log and return no metadata. try: parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body, parser) - og = _calc_og(tree, media_uri) + except LookupError: + # blindly consider the encoding as utf-8. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + except Exception as e: + logger.warning("Unable to create HTML parser: %s" % (e,)) + return {} + + def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(body_attempt, parser) + return _calc_og(tree, media_uri) + + # Attempt to parse the body. If this fails, log and return no metadata. + try: + return _attempt_calc_og(body) except UnicodeDecodeError: # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) - og = _calc_og(tree, media_uri) - - return og + return _attempt_calc_og(body.decode("utf-8", "ignore")) -def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: +def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d6880f2e6e..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,7 +16,7 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from twisted.web.http import Request @@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, @@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos, - ) -> dict: + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/synapse/server.py b/synapse/server.py index 9cdda83aa1..9bdd3177d7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.saml_handler import SamlHandler @@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_external_cache(self) -> ExternalCache: + return ExternalCache(self) + + @cache_in_self + def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: + if not self.config.redis.redis_enabled: + return None + + # We only want to import redis module if we're using it, as we have + # `txredisapi` as an optional dependency. + from synapse.replication.tcp.redis import lazyConnection + + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis_host, + self.config.redis_port, + ) + + return lazyConnection( + hs=self, + host=self.config.redis_host, + port=self.config.redis_port, + password=self.config.redis.redis_password, + reconnect=True, + ) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 84f59c7d85..3bd9ff8ca0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -310,6 +310,7 @@ class StateHandler: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None + entry = None else: # otherwise, we'll need to resolve the state across the prev_events. @@ -340,9 +341,13 @@ class StateHandler: current_state_ids=state_ids_before_event, ) - # XXX: can we update the state cache entry for the new state group? or - # could we set a flag on resolve_state_groups_for_events to tell it to - # always make a state group? + # Assign the new state group to the cached state entry. + # + # Note that this can race in that we could generate multiple state + # groups for the same state entry, but that is just inefficient + # rather than dangerous. + if entry and entry.state_group is None: + entry.state_group = state_group_before_event # # now if it's not a state event, we're done diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c7220bc778..d2ba4bd2fc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -262,6 +262,12 @@ class LoggingTransaction: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: + """Similar to `executemany`, except `txn.rowcount` will not be correct + afterwards. + + More efficient than `executemany` on PostgreSQL + """ + if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch # type: ignore diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index ae561a2da3..5d0845588c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore +from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -118,6 +119,7 @@ class DataStore( UIAuthStore, CacheInvalidationWorkerStore, ServerMetricsStore, + EventForwardExtremitiesStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9097677648..659d8f245f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore): DELETE FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) logger.info("Pruned %d device list outbound pokes", count) @@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Delete older entries in the table, as we really only care about # when the latest change happened. - txn.executemany( + txn.execute_batch( """ DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c128889bf9..309f1e865b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, dict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + ) -> Dict[str, Dict[str, Dict[str, str]]]: """Take a list of one time keys out of the database. Args: diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 1b657191a9..438383abe1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): VALUES (?, ?, ?, ?, ?, ?) """ - txn.executemany( + txn.execute_batch( sql, ( _gen_entry(user_id, actions) @@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ], ) - txn.executemany( + txn.execute_batch( """ UPDATE event_push_summary SET notif_count = ?, unread_count = ?, stream_ordering = ? diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5db7d7aaa8..ccda9f1caa 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -473,8 +473,9 @@ class PersistEventsStore: txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, ) - @staticmethod + @classmethod def _add_chain_cover_index( + cls, txn, db_pool: DatabasePool, event_to_room_id: Dict[str, str], @@ -614,60 +615,17 @@ class PersistEventsStore: if not events_to_calc_chain_id_for: return - # We now calculate the chain IDs/sequence numbers for the events. We - # do this by looking at the chain ID and sequence number of any auth - # event with the same type/state_key and incrementing the sequence - # number by one. If there was no match or the chain ID/sequence - # number is already taken we generate a new chain. - # - # We need to do this in a topologically sorted order as we want to - # generate chain IDs/sequence numbers of an event's auth events - # before the event itself. - chains_tuples_allocated = set() # type: Set[Tuple[int, int]] - new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] - for event_id in sorted_topologically( - events_to_calc_chain_id_for, event_to_auth_chain - ): - existing_chain_id = None - for auth_id in event_to_auth_chain.get(event_id, []): - if event_to_types.get(event_id) == event_to_types.get(auth_id): - existing_chain_id = chain_map[auth_id] - break - - new_chain_tuple = None - if existing_chain_id: - # We found a chain ID/sequence number candidate, check its - # not already taken. - proposed_new_id = existing_chain_id[0] - proposed_new_seq = existing_chain_id[1] + 1 - if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: - already_allocated = db_pool.simple_select_one_onecol_txn( - txn, - table="event_auth_chains", - keyvalues={ - "chain_id": proposed_new_id, - "sequence_number": proposed_new_seq, - }, - retcol="event_id", - allow_none=True, - ) - if already_allocated: - # Mark it as already allocated so we don't need to hit - # the DB again. - chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) - else: - new_chain_tuple = ( - proposed_new_id, - proposed_new_seq, - ) - - if not new_chain_tuple: - new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) - - chains_tuples_allocated.add(new_chain_tuple) - - chain_map[event_id] = new_chain_tuple - new_chain_tuples[event_id] = new_chain_tuple + # Allocate chain ID/sequence numbers to each new event. + new_chain_tuples = cls._allocate_chain_ids( + txn, + db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, + events_to_calc_chain_id_for, + chain_map, + ) + chain_map.update(new_chain_tuples) db_pool.simple_insert_many_txn( txn, @@ -794,6 +752,137 @@ class PersistEventsStore: ], ) + @staticmethod + def _allocate_chain_ids( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + events_to_calc_chain_id_for: Set[str], + chain_map: Dict[str, Tuple[int, int]], + ) -> Dict[str, Tuple[int, int]]: + """Allocates, but does not persist, chain ID/sequence numbers for the + events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index + for info on args) + """ + + # We now calculate the chain IDs/sequence numbers for the events. We do + # this by looking at the chain ID and sequence number of any auth event + # with the same type/state_key and incrementing the sequence number by + # one. If there was no match or the chain ID/sequence number is already + # taken we generate a new chain. + # + # We try to reduce the number of times that we hit the database by + # batching up calls, to make this more efficient when persisting large + # numbers of state events (e.g. during joins). + # + # We do this by: + # 1. Calculating for each event which auth event will be used to + # inherit the chain ID, i.e. converting the auth chain graph to a + # tree that we can allocate chains on. We also keep track of which + # existing chain IDs have been referenced. + # 2. Fetching the max allocated sequence number for each referenced + # existing chain ID, generating a map from chain ID to the max + # allocated sequence number. + # 3. Iterating over the tree and allocating a chain ID/seq no. to the + # new event, by incrementing the sequence number from the + # referenced event's chain ID/seq no. and checking that the + # incremented sequence number hasn't already been allocated (by + # looking in the map generated in the previous step). We generate a + # new chain if the sequence number has already been allocated. + # + + existing_chains = set() # type: Set[int] + tree = [] # type: List[Tuple[str, Optional[str]]] + + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events before + # the event itself. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + for auth_id in event_to_auth_chain.get(event_id, []): + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map.get(auth_id) + if existing_chain_id: + existing_chains.add(existing_chain_id[0]) + + tree.append((event_id, auth_id)) + break + else: + tree.append((event_id, None)) + + # Fetch the current max sequence number for each existing referenced chain. + sql = """ + SELECT chain_id, MAX(sequence_number) FROM event_auth_chains + WHERE %s + GROUP BY chain_id + """ + clause, args = make_in_list_sql_clause( + db_pool.engine, "chain_id", existing_chains + ) + txn.execute(sql % (clause,), args) + + chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + + # Allocate the new events chain ID/sequence numbers. + # + # To reduce the number of calls to the database we don't allocate a + # chain ID number in the loop, instead we use a temporary `object()` for + # each new chain ID. Once we've done the loop we generate the necessary + # number of new chain IDs in one call, replacing all temporary + # objects with real allocated chain IDs. + + unallocated_chain_ids = set() # type: Set[object] + new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + for event_id, auth_event_id in tree: + # If we reference an auth_event_id we fetch the allocated chain ID, + # either from the existing `chain_map` or the newly generated + # `new_chain_tuples` map. + existing_chain_id = None + if auth_event_id: + existing_chain_id = new_chain_tuples.get(auth_event_id) + if not existing_chain_id: + existing_chain_id = chain_map[auth_event_id] + + new_chain_tuple = None # type: Optional[Tuple[Any, int]] + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + + if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + # If we need to start a new chain we allocate a temporary chain ID. + if not new_chain_tuple: + new_chain_tuple = (object(), 1) + unallocated_chain_ids.add(new_chain_tuple[0]) + + new_chain_tuples[event_id] = new_chain_tuple + chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] + + # Generate new chain IDs for all unallocated chain IDs. + newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + txn, len(unallocated_chain_ids) + ) + + # Map from potentially temporary chain ID to real chain ID + chain_id_to_allocated_map = dict( + zip(unallocated_chain_ids, newly_allocated_chain_ids) + ) # type: Dict[Any, int] + chain_id_to_allocated_map.update((c, c) for c in existing_chains) + + return { + event_id: (chain_id_to_allocated_map[chain_id], seq) + for event_id, (chain_id, seq) in new_chain_tuples.items() + } + def _persist_transaction_ids_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index e46e44ba54..5ca4fa6817 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_txn(txn): sql = ( "SELECT stream_ordering, event_id, json FROM events" @@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, update_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_search_txn(txn): sql = ( "SELECT stream_ordering, event_id FROM events" @@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, rows_to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py new file mode 100644 index 0000000000..0ac1da9c35 --- /dev/null +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, List + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class EventForwardExtremitiesStore(SQLBaseStore): + async def delete_forward_extremities_for_room(self, room_id: str) -> int: + """Delete any extra forward extremities for a room. + + Invalidates the "get_latest_event_ids_in_room" cache if any forward + extremities were deleted. + + Returns count deleted. + """ + + def delete_forward_extremities_for_room_txn(txn): + # First we need to get the event_id to not delete + sql = """ + SELECT event_id FROM event_forward_extremities + INNER JOIN events USING (room_id, event_id) + WHERE room_id = ? + ORDER BY stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (room_id,)) + rows = txn.fetchall() + try: + event_id = rows[0][0] + logger.debug( + "Found event_id %s as the forward extremity to keep for room %s", + event_id, + room_id, + ) + except KeyError: + msg = "No forward extremity event found for room %s" % room_id + logger.warning(msg) + raise SynapseError(400, msg) + + # Now delete the extra forward extremities + sql = """ + DELETE FROM event_forward_extremities + WHERE event_id != ? AND room_id = ? + """ + + txn.execute(sql, (event_id, room_id)) + logger.info( + "Deleted %s extra forward extremities for room %s", + txn.rowcount, + room_id, + ) + + if txn.rowcount > 0: + # Invalidate the cache + self._invalidate_cache_and_stream( + txn, self.get_latest_event_ids_in_room, (room_id,), + ) + + return txn.rowcount + + return await self.db_pool.runInteraction( + "delete_forward_extremities_for_room", + delete_forward_extremities_for_room_txn, + ) + + async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + """Get list of forward extremities for a room.""" + + def get_forward_extremities_for_room_txn(txn): + sql = """ + SELECT event_id, state_group, depth, received_ts + FROM event_forward_extremities + INNER JOIN event_to_state_groups USING (event_id) + INNER JOIN events USING (room_id, event_id) + WHERE room_id = ? + """ + + txn.execute(sql, (room_id,)) + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, + ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 283c8a5e22..e017177655 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_origin = ? AND media_id = ?" ) - txn.executemany( + txn.execute_batch( sql, ( (time_ms, media_origin, media_id) @@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_id = ?" ) - txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn @@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" def _delete_url_cache_txn(txn): - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache", _delete_url_cache_txn @@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_media_txn(txn): sql = "DELETE FROM local_media_repository WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index ab18cc4d79..92e65aa640 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (x[0] - 1) * x[1] for x in res if x[1] ) + async def count_daily_e2ee_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.encrypted' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) + + async def count_daily_sent_e2ee_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.encrypted' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction( + "count_daily_sent_e2ee_messages", _count_messages + ) + + async def count_daily_active_e2ee_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.encrypted' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction( + "count_daily_active_e2ee_rooms", _count + ) + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 5d668aadb2..ecfc9f20b1 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): ) # Update backward extremeties - txn.executemany( + txn.execute_batch( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", [(room_id, event_id) for event_id, in new_backwards_extrems], diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bc7621b8d6..2687ef3e43 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self.db_pool.simple_delete_one_txn( + # It is expected that there is exactly one pusher to delete, but + # if it isn't there (or there are multiple) delete them all. + self.db_pool.simple_delete_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 14c0878d81..8405dd460f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: + """Sets whether a user shadow-banned. + + Args: + user: user ID of the user to test + shadow_banned: true iff the user is to be shadow-banned, false otherwise. + """ + + def set_shadow_banned_txn(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"shadow_banned": shadow_banned}, + ) + # In order for this to apply immediately, clear the cache for this user. + tokens = self.db_pool.simple_select_onecol_txn( + txn, + table="access_tokens", + keyvalues={"user_id": user.to_string()}, + retcol="token", + ) + for token in tokens: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, @@ -1124,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): FROM user_threepids """ - txn.executemany(sql, [(id_server,) for id_server in id_servers]) + txn.execute_batch(sql, [(id_server,) for id_server in id_servers]) if id_servers: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index dcdaf09682..92382bed28 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): "max_stream_id_exclusive", self._stream_order_on_start + 1 ) - INSERT_CLUMP_SIZE = 1000 - def add_membership_profile_txn(txn): sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json @@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): UPDATE room_memberships SET display_name = ?, avatar_url = ? WHERE event_id = ? AND room_id = ? """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) + txn.execute_batch(to_update_sql, to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py index f35c70b699..9e8f35c1d2 100644 --- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py +++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py @@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs # { "ignored_users": "@someone:example.org": {} } ignored_users = content.get("ignored_users", {}) if isinstance(ignored_users, dict) and ignored_users: - cur.executemany(insert_sql, [(user_id, u) for u in ignored_users]) + cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users]) # Add indexes after inserting data for efficiency. logger.info("Adding constraints to ignored_users table") diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e34fce6281..f5e7d9ef98 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import Collection logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( @@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore): async def search_rooms( self, - room_ids: List[str], + room_ids: Collection[str], search_term: str, keys: List[str], limit, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 0cdb3ec1f7..d421d18f8d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -15,11 +15,12 @@ # limitations under the License. import logging -from collections import Counter from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Tuple +from typing_extensions import Counter + from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: + async def get_earliest_token_for_stats( + self, stats_type: str, id: str + ) -> Optional[int]: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore): ) async def bulk_update_stats_delta( - self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. @@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore): async def get_changes_room_total_events_and_bytes( self, min_pos: int, max_pos: int - ) -> Dict[str, Dict[str, int]]: + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Fetches the counts of events in the given range of stream IDs. Args: @@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore): max_pos, ) - def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + def get_changes_room_total_events_and_bytes_txn( + self, txn, low_pos: int, high_pos: int + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Gets the total_events and total_event_bytes counts for rooms and senders, in a range of stream_orderings (including backfilled events). Args: txn - low_pos (int): Low stream ordering - high_pos (int): High stream ordering + low_pos: Low stream ordering + high_pos: High stream ordering Returns: - tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The - room and user deltas for total_events/total_event_bytes in the + The room and user deltas for total_events/total_event_bytes in the format of `stats_id` -> fields """ diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ef11f1c3b3..7b9729da09 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - async def update_user_directory_stream_pos(self, stream_id: str) -> None: + async def update_user_directory_stream_pos(self, stream_id: int) -> None: await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0e31cc811a..89cdc84a9c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) logger.info("[purge] removing redundant state groups") - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups_state WHERE state_group = ?", ((sg,) for sg in state_groups_to_delete), ) - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups WHERE id = ?", ((sg,) for sg in state_groups_to_delete), ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index bb84c0d792..71ef5a72dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -15,12 +15,11 @@ import heapq import logging import threading -from collections import deque +from collections import OrderedDict from contextlib import contextmanager from typing import Dict, List, Optional, Set, Tuple, Union import attr -from typing_extensions import Deque from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction @@ -101,7 +100,13 @@ class StreamIdGenerator: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() # type: Deque[int] + + # We use this as an ordered set, as we want to efficiently append items, + # remove items and get the first item. Since we insert IDs in order, the + # insertion ordering will ensure its in the correct ordering. + # + # The key and values are the same, but we never look at the values. + self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] def get_next(self): """ @@ -113,7 +118,7 @@ class StreamIdGenerator: self._current += self._step next_id = self._current - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -121,7 +126,7 @@ class StreamIdGenerator: yield next_id finally: with self._lock: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -140,7 +145,7 @@ class StreamIdGenerator: self._current += n * self._step for next_id in next_ids: - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -149,7 +154,7 @@ class StreamIdGenerator: finally: with self._lock: for next_id in next_ids: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -162,7 +167,7 @@ class StreamIdGenerator: """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._step + return next(iter(self._unfinished_ids)) - self._step return self._current diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index c780ade077..0ec4dc2918 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): ... @abc.abstractmethod + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + """Get the next `n` IDs in the sequence""" + ... + + @abc.abstractmethod def check_consistency( self, db_conn: "LoggingDatabaseConnection", @@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + first_id = self._current_max_id + 1 + self._current_max_id += n + return [first_id + i for i in range(n)] + def check_consistency( self, db_conn: Connection, diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 1ee61851e4..09b094ded7 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: module = importlib.import_module(module) provider_class = getattr(module, clz) - module_config = provider.get("config") + # Load the module config. If None, pass an empty dictionary instead + module_config = provider.get("config") or {} try: provider_config = provider_class.parse_config(module_config) except jsonschema.ValidationError as e: |