Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py                       |  2
-rw-r--r--  synapse/config/_base.py                   | 10
-rw-r--r--  synapse/config/saml2_config.py            |  4
-rw-r--r--  synapse/federation/federation_client.py   | 15
-rw-r--r--  synapse/federation/federation_server.py   | 52
-rw-r--r--  synapse/federation/transport/server.py    | 13
-rw-r--r--  synapse/handlers/federation.py            | 65
-rw-r--r--  synapse/handlers/pagination.py            |  8
-rw-r--r--  synapse/rest/client/versions.py           | 19
-rw-r--r--  synapse/server.py                         |  5
-rw-r--r--  synapse/storage/databases/main/stream.py  | 13
11 files changed, 157 insertions, 49 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index bf0bf192a5..e40b582bd5 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.20.0rc3"
+__version__ = "1.20.1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index ad5ab6ad62..f8ab8e38df 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -194,7 +194,10 @@ class Config:
             return file_stream.read()
 
     def read_templates(
-        self, filenames: List[str], custom_template_directory: Optional[str] = None,
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
+        autoescape: bool = False,
     ) -> List[jinja2.Template]:
         """Load a list of template files from disk using the given variables.
 
@@ -210,6 +213,9 @@ class Config:
             custom_template_directory: A directory to try to look for the templates
                 before using the default Synapse template directory instead.
 
+            autoescape: Whether to autoescape variables before inserting them into the
+                template.
+
         Raises:
             ConfigError: if the file's path is incorrect or otherwise cannot be read.
 
@@ -233,7 +239,7 @@ class Config:
             search_directories.insert(0, custom_template_directory)
 
         loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=True)
+        env = jinja2.Environment(loader=loader, autoescape=autoescape)
 
         # Update the environment with our custom filters
         env.filters.update(
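The new `autoescape` parameter defaults to off, so call sites rendering untrusted input must now opt in (as the SAML change below does). A minimal, standalone sketch of what the flag changes, assuming nothing beyond stock jinja2; the template string and variable are illustrative, not from Synapse:

```python
import jinja2

template_src = "<p>{{ msg }}</p>"
payload = "<script>alert(1)</script>"

# autoescape=True renders untrusted input inert:
escaped = jinja2.Environment(autoescape=True)
print(escaped.from_string(template_src).render(msg=payload))
# -> <p>&lt;script&gt;alert(1)&lt;/script&gt;</p>

# autoescape=False (the new default for read_templates) inserts it verbatim:
plain = jinja2.Environment(autoescape=False)
print(plain.from_string(template_src).render(msg=payload))
# -> <p><script>alert(1)</script></p>
```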
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index cc7401888b..755478e2ff 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -169,8 +169,10 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "15m")
         )
 
+        # We enable autoescape here as the message may potentially come from a
+        # remote resource
         self.saml2_error_html_template = self.read_templates(
-            ["saml_error.html"], saml2_config.get("template_dir")
+            ["saml_error.html"], saml2_config.get("template_dir"), autoescape=True
         )[0]
 
     def _default_saml_config_dict(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 38ac7ec699..d42930d1b9 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -54,7 +54,7 @@ from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.utils import log_function
-from synapse.types import JsonDict
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -217,11 +217,9 @@ class FederationClient(FederationBase):
             for p in transaction_data["pdus"]
         ]
 
-        # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = await make_deferred_yieldable(
-            defer.gatherResults(
-                self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
-            ).addErrback(unwrapFirstError)
+        # Check signatures and hash of pdus, removing any from the list that fail checks
+        pdus[:] = await self._check_sigs_and_hash_and_fetch(
+            dest, pdus, outlier=True, room_version=room_version
         )
 
         return pdus
@@ -386,10 +384,11 @@ class FederationClient(FederationBase):
                     pdu.event_id, allow_rejected=True, allow_none=True
                 )
 
-            if not res and pdu.origin != origin:
+            pdu_origin = get_domain_from_id(pdu.sender)
+            if not res and pdu_origin != origin:
                 try:
                     res = await self.get_pdu(
-                        destinations=[pdu.origin],
+                        destinations=[pdu_origin],
                         event_id=pdu.event_id,
                         room_version=room_version,
                         outlier=outlier,
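The second hunk stops trusting the PDU's self-reported `origin` field and derives the origin server from the sender's user ID instead. A rough sketch of what `get_domain_from_id` does; the real helper lives in `synapse.types` and raises a `SynapseError` rather than a `ValueError`:

```python
# Matrix IDs look like "@localpart:domain"; the server name is everything
# after the first colon. Sketch only; error handling differs in Synapse.
def get_domain_from_id(string: str) -> str:
    idx = string.find(":")
    if idx == -1:
        raise ValueError("Invalid ID: %r" % (string,))
    return string[idx + 1 :]

assert get_domain_from_id("@alice:example.com") == "example.com"
```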
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 218df884b0..ff00f0b302 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -97,10 +97,16 @@ class FederationServer(FederationBase):
         self.state = hs.get_state_handler()
 
         self.device_handler = hs.get_device_handler()
+        self._federation_ratelimiter = hs.get_federation_ratelimiter()
 
         self._server_linearizer = Linearizer("fed_server")
         self._transaction_linearizer = Linearizer("fed_txn_handler")
 
+        # We cache results for transactions with the same ID
+        self._transaction_resp_cache = ResponseCache(
+            hs, "fed_txn_handler", timeout_ms=30000
+        )
+
         self.transaction_actions = TransactionActions(self.store)
 
         self.registry = hs.get_federation_registry()
@@ -135,22 +141,44 @@ class FederationServer(FederationBase):
         request_time = self._clock.time_msec()
 
         transaction = Transaction(**transaction_data)
+        transaction_id = transaction.transaction_id  # type: ignore
 
-        if not transaction.transaction_id:  # type: ignore
+        if not transaction_id:
             raise Exception("Transaction missing transaction_id")
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Got transaction", transaction_id)
 
-        # use a linearizer to ensure that we don't process the same transaction
-        # multiple times in parallel.
-        with (
-            await self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)  # type: ignore
-            )
-        ):
-            result = await self._handle_incoming_transaction(
-                origin, transaction, request_time
-            )
+        # We wrap in a ResponseCache so that we de-duplicate retried
+        # transactions.
+        return await self._transaction_resp_cache.wrap(
+            (origin, transaction_id),
+            self._on_incoming_transaction_inner,
+            origin,
+            transaction,
+            request_time,
+        )
+
+    async def _on_incoming_transaction_inner(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
+        # Use a linearizer to ensure that transactions from a remote are
+        # processed in order.
+        with await self._transaction_linearizer.queue(origin):
+            # We rate limit here *after* we've queued up the incoming requests,
+            # so that we don't fill up the ratelimiter with blocked requests.
+            #
+            # This is important as the ratelimiter allows N concurrent requests
+            # at a time, and only starts ratelimiting if there are more requests
+            # than that being processed at a time. If we queued up requests in
+            # the linearizer/response cache *after* the ratelimiting then those
+            # queued up requests would count as part of the allowed limit of N
+            # concurrent requests.
+            with self._federation_ratelimiter.ratelimit(origin) as d:
+                await d
+
+                result = await self._handle_incoming_transaction(
+                    origin, transaction, request_time
+                )
 
         return result
 
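The key property is that a remote retrying a transaction with the same ID gets the original response instead of having the transaction reprocessed. A rough asyncio-based sketch of the de-duplication that `ResponseCache.wrap` provides; Synapse's real implementation is Twisted-based and evicts entries after `timeout_ms`, which is omitted here:

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

class SimpleResponseCache:
    """Shares one in-flight computation between callers using the same key."""

    def __init__(self) -> None:
        self._pending: Dict[Any, "asyncio.Task[Any]"] = {}

    async def wrap(
        self, key: Any, callback: Callable[..., Awaitable[Any]], *args: Any
    ) -> Any:
        task = self._pending.get(key)
        if task is None:
            # First request for this key: start the work and remember it.
            task = asyncio.ensure_future(callback(*args))
            self._pending[key] = task
        # Retries with the same (origin, transaction_id) key await the same
        # task, so the transaction body is only processed once.
        return await asyncio.shield(task)
```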
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 9325e0f857..cc7e9a973b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -45,7 +45,6 @@ from synapse.logging.opentracing import (
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -72,9 +71,7 @@ class TransportLayerServer(JsonResource):
         super(TransportLayerServer, self).__init__(hs, canonical_json=False)
 
         self.authenticator = Authenticator(hs)
-        self.ratelimiter = FederationRateLimiter(
-            self.clock, config=hs.config.rc_federation
-        )
+        self.ratelimiter = hs.get_federation_ratelimiter()
 
         self.register_servlets()
 
@@ -272,6 +269,8 @@ class BaseFederationServlet:
 
     PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
 
+    RATELIMIT = True  # Whether to rate limit requests or not
+
     def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
@@ -335,7 +334,7 @@ class BaseFederationServlet:
                 )
 
             with scope:
-                if origin:
+                if origin and self.RATELIMIT:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
                         if request._disconnected:
@@ -372,6 +371,10 @@ class BaseFederationServlet:
 class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/(?P<transaction_id>[^/]*)/?"
 
+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
     def __init__(self, handler, server_name, **kwargs):
         super(FederationSendServlet, self).__init__(
             handler, server_name=server_name, **kwargs
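`RATELIMIT` is a plain class attribute, so a servlet opts out of the shared limiter simply by overriding it. A minimal sketch of the pattern, with illustrative class and method names:

```python
class BaseServlet:
    RATELIMIT = True  # subclasses override this to skip the shared limiter

    async def handle(self, origin: str) -> None:
        if origin and self.RATELIMIT:
            ...  # wrap the request in ratelimiter.ratelimit(origin)

class SendServlet(BaseServlet):
    # /send is rate limited later, inside the transaction handler, so queued
    # retries don't consume limiter slots while they wait.
    RATELIMIT = False
```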
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 43f2986f89..014dab2940 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -943,15 +943,26 @@ class FederationHandler(BaseHandler):
 
         return events
 
-    async def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(
+        self, room_id: str, current_depth: int, limit: int
+    ) -> bool:
         """Checks the database to see if we should backfill before paginating,
         and if so do.
+
+        Args:
+            room_id
+            current_depth: The depth from which we're paginating. This is
+                used to decide if we should backfill and what extremities to
+                use.
+            limit: The number of events that the pagination request will
+                return. This is used as part of the heuristic to decide if we
+                should backfill.
         """
         extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
-            return
+            return False
 
         # We only want to paginate if we can actually see the events we'll get,
         # as otherwise we'll just spend a lot of resources to get redacted
@@ -1004,16 +1015,54 @@ class FederationHandler(BaseHandler):
         sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
         max_depth = sorted_extremeties_tuple[0][1]
 
+        # If we're approaching an extremity we trigger a backfill, otherwise we
+        # no-op.
+        #
+        # We chose twice the limit here so that clients paginating backwards
+        # will send pagination requests that trigger backfill at least twice
+        # from the most recent extremity before it gets removed (see below).
+        # We chose a factor greater than one to allow for failures, but a
+        # much larger factor would trigger backfill requests much earlier
+        # than necessary.
+        if current_depth - 2 * limit > max_depth:
+            logger.debug(
+                "Not backfilling as we don't need to. %d < %d - 2 * %d",
+                max_depth,
+                current_depth,
+                limit,
+            )
+            return False
+
+        logger.debug(
+            "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
+            room_id,
+            current_depth,
+            max_depth,
+            sorted_extremeties_tuple,
+        )
+
+        # We ignore extremities that have a greater depth than our current depth
+        # as:
+        #    1. we don't really care about getting events that have happened
+        #       before our current position; and
+        #    2. we have likely previously tried and failed to backfill from that
+        #       extremity, so to avoid getting "stuck" requesting the same
+        #       backfill repeatedly we drop those extremities.
+        filtered_sorted_extremeties_tuple = [
+            t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
+        ]
+
+        # However, we need to check that the filtered extremities are non-empty.
+        # If they are empty then we can either a) bail or b) still attempt to
+        # backfill. We opt to try backfilling anyway just in case we do get
+        # relevant events.
+        if filtered_sorted_extremeties_tuple:
+            sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+
         # We don't want to specify too many extremities as it causes the backfill
         # request URI to be too long.
         extremities = dict(sorted_extremeties_tuple[:5])
 
-        if current_depth > max_depth:
-            logger.debug(
-                "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
-            )
-            return
-
         # Now we need to decide which hosts to hit first.
 
         # First we try hosts that are already in the room
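A toy illustration of the new trigger condition: backfill is attempted only while the pagination position is within `2 * limit` of the deepest extremity. The depths and limit values below are made up:

```python
def should_backfill(current_depth: int, max_extremity_depth: int, limit: int) -> bool:
    # Mirrors the check in maybe_backfill: being more than 2 * limit events
    # away from the deepest extremity means pagination can be served locally.
    return current_depth - 2 * limit <= max_extremity_depth

assert should_backfill(current_depth=120, max_extremity_depth=100, limit=20)
assert not should_backfill(current_depth=200, max_extremity_depth=100, limit=20)
```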
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 34ed0e2921..6067585f9b 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -362,9 +362,9 @@ class PaginationHandler:
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
-                    max_topo = room_token.topological
+                    curr_topo = room_token.topological
                 else:
-                    max_topo = await self.store.get_max_topological_token(
+                    curr_topo = await self.store.get_current_topological_token(
                         room_id, room_token.stream
                     )
 
@@ -380,11 +380,11 @@ class PaginationHandler:
                     leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    if RoomStreamToken.parse(leave_token).topological < max_topo:
+                    if RoomStreamToken.parse(leave_token).topological < curr_topo:
                         source_config.from_key = str(leave_token)
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
-                    room_id, max_topo
+                    room_id, curr_topo, limit=source_config.limit,
                 )
 
             events, next_key = await self.store.paginate_room_events(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 24ac57f35d..c560edbc59 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -19,6 +19,7 @@
 import logging
 import re
 
+from synapse.api.constants import RoomCreationPreset
 from synapse.http.servlet import RestServlet
 
 logger = logging.getLogger(__name__)
@@ -31,6 +32,20 @@ class VersionsRestServlet(RestServlet):
         super(VersionsRestServlet, self).__init__()
         self.config = hs.config
 
+        # Calculate these once since they shouldn't change after start-up.
+        self.e2ee_forced_public = (
+            RoomCreationPreset.PUBLIC_CHAT
+            in self.config.encryption_enabled_by_default_for_room_presets
+        )
+        self.e2ee_forced_private = (
+            RoomCreationPreset.PRIVATE_CHAT
+            in self.config.encryption_enabled_by_default_for_room_presets
+        )
+        self.e2ee_forced_trusted_private = (
+            RoomCreationPreset.TRUSTED_PRIVATE_CHAT
+            in self.config.encryption_enabled_by_default_for_room_presets
+        )
+
     def on_GET(self, request):
         return (
             200,
@@ -62,6 +77,10 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc2432": True,
                     # Implements additional endpoints as described in MSC2666
                     "uk.half-shot.msc2666": True,
+                    # Whether new rooms will be set to encrypted or not (based on presets).
+                    "io.element.e2ee_forced.public": self.e2ee_forced_public,
+                    "io.element.e2ee_forced.private": self.e2ee_forced_private,
+                    "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
                 },
             },
         )
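A hedged sketch of how a client might read the new flags; the response shape mirrors the handler above, but the config-derived values are invented:

```python
# Invented example /versions payload; only the relevant keys are shown.
versions_response = {
    "unstable_features": {
        "io.element.e2ee_forced.public": False,
        "io.element.e2ee_forced.private": True,
        "io.element.e2ee_forced.trusted_private": True,
    }
}

def encryption_forced(preset: str) -> bool:
    # Clients can use this to pre-select E2EE in their room-creation UI.
    key = "io.element.e2ee_forced.%s" % preset
    return versions_response["unstable_features"].get(key, False)

assert not encryption_forced("public")
assert encryption_forced("private")
```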
diff --git a/synapse/server.py b/synapse/server.py
index 9055b97ac3..5e3752c333 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -114,6 +114,7 @@ from synapse.streams.events import EventSources
 from synapse.types import DomainSpecificString
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
@@ -642,6 +643,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_replication_streams(self) -> Dict[str, Stream]:
         return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
 
+    @cache_in_self
+    def get_federation_ratelimiter(self) -> FederationRateLimiter:
+        return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
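`@cache_in_self` memoizes the getter on the `HomeServer` instance, which is what lets `TransportLayerServer` and `FederationServer` share a single `FederationRateLimiter`. A simplified approximation of the decorator; Synapse's actual implementation differs in detail:

```python
import functools

def cache_in_self(builder):
    attr_name = "_" + builder.__name__

    @functools.wraps(builder)
    def getter(self):
        # Build the dependency once, then reuse it for every later call.
        if not hasattr(self, attr_name):
            setattr(self, attr_name, builder(self))
        return getattr(self, attr_name)

    return getter

class FakeHomeServer:
    @cache_in_self
    def get_federation_ratelimiter(self):
        return object()  # stands in for FederationRateLimiter(...)

hs = FakeHomeServer()
assert hs.get_federation_ratelimiter() is hs.get_federation_ratelimiter()
```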
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index be6df8a6d1..db20a3db30 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -648,23 +648,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
 
-    async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
-        """Get the max topological token in a room before the given stream
+    async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
+        """Gets the topological token in a room after or at the given stream
         ordering.
 
         Args:
             room_id
             stream_key
-
-        Returns:
-            The maximum topological token.
         """
         sql = (
-            "SELECT coalesce(max(topological_ordering), 0) FROM events"
-            " WHERE room_id = ? AND stream_ordering < ?"
+            "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
+            " WHERE room_id = ? AND stream_ordering >= ?"
         )
         row = await self.db_pool.execute(
-            "get_max_topological_token", None, sql, room_id, stream_key
+            "get_current_topological_token", None, sql, room_id, stream_key
         )
         return row[0][0] if row else 0
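A small sqlite3 reproduction of the semantic change, with made-up rows: the old query returned the maximum topological ordering strictly before the stream position, while the new one returns the token current at or after it, which is what the pagination handler's `curr_topo` needs:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE events (room_id TEXT, topological_ordering INT, stream_ordering INT)"
)
conn.executemany(
    "INSERT INTO events VALUES (?, ?, ?)",
    [("!room", 1, 10), ("!room", 2, 20), ("!room", 3, 30)],
)

new_sql = (
    "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
    " WHERE room_id = ? AND stream_ordering >= ?"
)
old_sql = (
    "SELECT coalesce(MAX(topological_ordering), 0) FROM events"
    " WHERE room_id = ? AND stream_ordering < ?"
)

# At stream position 20 the current topological token is 2 ...
assert conn.execute(new_sql, ("!room", 20)).fetchone()[0] == 2
# ... whereas the old query returned the max strictly before it, i.e. 1.
assert conn.execute(old_sql, ("!room", 20)).fetchone()[0] == 1
```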