summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/app/generic_worker.py1
-rw-r--r--synapse/config/federation.py40
-rw-r--r--synapse/crypto/keyring.py4
-rw-r--r--synapse/federation/federation_server.py1
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/identity.py6
-rw-r--r--synapse/http/client.py46
-rw-r--r--synapse/http/federation/matrix_federation_agent.py16
-rw-r--r--synapse/http/matrixfederationclient.py26
-rw-r--r--synapse/push/httppusher.py2
-rw-r--r--synapse/rest/admin/users.py3
-rw-r--r--synapse/rest/media/v1/media_repository.py2
-rw-r--r--synapse/server.py36
-rw-r--r--synapse/state/v2.py87
16 files changed, 202 insertions, 74 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 65c1f5aa3f..d33a99f230 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.23.0"
+__version__ = "1.24.0rc1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890aa..aa12c74358 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -266,7 +266,6 @@ class GenericWorkerPresence(BasePresenceHandler):
         super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
-        self.http_client = hs.get_simple_http_client()
 
         self._presence_enabled = hs.config.use_presence
 
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca54e..27ccf61c3c 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -36,22 +36,30 @@ class FederationConfig(Config):
             for domain in federation_domain_whitelist:
                 self.federation_domain_whitelist[domain] = True
 
-        self.federation_ip_range_blacklist = config.get(
-            "federation_ip_range_blacklist", []
-        )
+        ip_range_blacklist = config.get("ip_range_blacklist", [])
 
         # Attempt to create an IPSet from the given ranges
         try:
-            self.federation_ip_range_blacklist = IPSet(
-                self.federation_ip_range_blacklist
-            )
-
-            # Always blacklist 0.0.0.0, ::
-            self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+            self.ip_range_blacklist = IPSet(ip_range_blacklist)
+        except Exception as e:
+            raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e)
+        # Always blacklist 0.0.0.0, ::
+        self.ip_range_blacklist.update(["0.0.0.0", "::"])
+
+        # The federation_ip_range_blacklist is used for backwards-compatibility
+        # and only applies to federation and identity servers. If it is not given,
+        # default to ip_range_blacklist.
+        federation_ip_range_blacklist = config.get(
+            "federation_ip_range_blacklist", ip_range_blacklist
+        )
+        try:
+            self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
         except Exception as e:
             raise ConfigError(
                 "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
             )
+        # Always blacklist 0.0.0.0, ::
+        self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
 
         federation_metrics_domains = config.get("federation_metrics_domains") or []
         validate_config(
@@ -76,17 +84,19 @@ class FederationConfig(Config):
         #  - nyc.example.com
         #  - syd.example.com
 
-        # Prevent federation requests from being sent to the following
-        # blacklist IP address CIDR ranges. If this option is not specified, or
-        # specified with an empty list, no ip range blacklist will be enforced.
+        # Prevent outgoing requests from being sent to the following blacklisted IP address
+        # CIDR ranges. If this option is not specified, or specified with an empty list,
+        # no IP range blacklist will be enforced.
         #
-        # As of Synapse v1.4.0 this option also affects any outbound requests to identity
-        # servers provided by user input.
+        # The blacklist applies to the outbound requests for federation, identity servers,
+        # push servers, and for checking key validitity for third-party invite events.
         #
         # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
         # listed here, since they correspond to unroutable addresses.)
         #
-        federation_ip_range_blacklist:
+        # This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
+        #
+        ip_range_blacklist:
           - '127.0.0.0/8'
           - '10.0.0.0/8'
           - '172.16.0.0/12'
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77cf9..f23eacc0d7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -578,7 +578,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
     def __init__(self, hs):
         super().__init__(hs)
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
     async def get_keys(self, keys_to_fetch):
@@ -748,7 +748,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
     def __init__(self, hs):
         super().__init__(hs)
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
 
     async def get_keys(self, keys_to_fetch):
         """
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4b6ab470d0..35e345ce70 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -845,7 +845,6 @@ class FederationHandlerRegistry:
 
     def __init__(self, hs: "HomeServer"):
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
         self.clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..abe9168c78 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient:
 
     def __init__(self, hs):
         self.server_name = hs.hostname
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
 
     @log_function
     def get_room_state_ids(self, destination, room_id, event_id):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..df82e60b33 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
         self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self._instance_name = hs.get_instance_name()
         self._replication = hs.get_replication_data_handler()
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
     def __init__(self, hs):
         super().__init__(hs)
 
+        # An HTTP client for contacting trusted URLs.
         self.http_client = SimpleHttpClient(hs)
-        # We create a blacklisting instance of SimpleHttpClient for contacting identity
-        # servers specified by clients
+        # An HTTP client for contacting identity servers specified by clients.
         self.blacklisting_http_client = SimpleHttpClient(
             hs, ip_blacklist=hs.config.federation_ip_range_blacklist
         )
-        self.federation_http_client = hs.get_http_client()
+        self.federation_http_client = hs.get_federation_http_client()
         self.hs = hs
 
     async def threepid_from_creds(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593f2..df7730078f 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
     return _scheduler
 
 
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
     addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
         return r
 
 
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+    """
+    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+    addresses, to prevent DNS rebinding.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        self._reactor = reactor
+
+        # We need to use a DNS resolver which filters out blacklisted IP
+        # addresses, to prevent DNS rebinding.
+        self._nameResolver = _IPBlacklistingResolver(
+            self._reactor, ip_whitelist, ip_blacklist
+        )
+
+    def __getattr__(self, attr: str) -> Any:
+        # Passthrough to the real reactor except for the DNS resolver.
+        if attr == "nameResolver":
+            return self._nameResolver
+        else:
+            return getattr(self._reactor, attr)
+
+
 class BlacklistingAgentWrapper(Agent):
     """
     An Agent wrapper which will prevent access to IP addresses being accessed
@@ -292,22 +321,11 @@ class SimpleHttpClient:
         self.user_agent = self.user_agent.encode("ascii")
 
         if self._ip_blacklist:
-            real_reactor = hs.get_reactor()
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
-            nameResolver = IPBlacklistingResolver(
-                real_reactor, self._ip_whitelist, self._ip_blacklist
+            self.reactor = BlacklistingReactorWrapper(
+                hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
             )
-
-            @implementer(IReactorPluggableNameResolver)
-            class Reactor:
-                def __getattr__(_self, attr):
-                    if attr == "nameResolver":
-                        return nameResolver
-                    else:
-                        return getattr(real_reactor, attr)
-
-            self.reactor = Reactor()
         else:
             self.reactor = hs.get_reactor()
 
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index e77f9587d0..3b756a7dc2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -16,7 +16,7 @@ import logging
 import urllib.parse
 from typing import List, Optional
 
-from netaddr import AddrFormatError, IPAddress
+from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
 
 from twisted.internet import defer
@@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
 
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.http.client import BlacklistingAgentWrapper
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -70,6 +71,7 @@ class MatrixFederationAgent:
         reactor: IReactorCore,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
+        ip_blacklist: IPSet,
         _srv_resolver: Optional[SrvResolver] = None,
         _well_known_resolver: Optional[WellKnownResolver] = None,
     ):
@@ -90,12 +92,18 @@ class MatrixFederationAgent:
         self.user_agent = user_agent
 
         if _well_known_resolver is None:
+            # Note that the name resolver has already been wrapped in a
+            # IPBlacklistingResolver by MatrixFederationHttpClient.
             _well_known_resolver = WellKnownResolver(
                 self._reactor,
-                agent=Agent(
+                agent=BlacklistingAgentWrapper(
+                    Agent(
+                        self._reactor,
+                        pool=self._pool,
+                        contextFactory=tls_client_options_factory,
+                    ),
                     self._reactor,
-                    pool=self._pool,
-                    contextFactory=tls_client_options_factory,
+                    ip_blacklist=ip_blacklist,
                 ),
                 user_agent=self.user_agent,
             )
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4e27f93b7a..c962994727 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,11 +26,10 @@ import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
-from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
+from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
@@ -45,7 +44,7 @@ from synapse.api.errors import (
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
-    IPBlacklistingResolver,
+    BlacklistingReactorWrapper,
     encode_query_args,
     readBodyToFile,
 )
@@ -221,31 +220,22 @@ class MatrixFederationHttpClient:
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
 
-        real_reactor = hs.get_reactor()
-
         # We need to use a DNS resolver which filters out blacklisted IP
         # addresses, to prevent DNS rebinding.
-        nameResolver = IPBlacklistingResolver(
-            real_reactor, None, hs.config.federation_ip_range_blacklist
+        self.reactor = BlacklistingReactorWrapper(
+            hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
         )
 
-        @implementer(IReactorPluggableNameResolver)
-        class Reactor:
-            def __getattr__(_self, attr):
-                if attr == "nameResolver":
-                    return nameResolver
-                else:
-                    return getattr(real_reactor, attr)
-
-        self.reactor = Reactor()
-
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
         user_agent = user_agent.encode("ascii")
 
         self.agent = MatrixFederationAgent(
-            self.reactor, tls_client_options_factory, user_agent
+            self.reactor,
+            tls_client_options_factory,
+            user_agent,
+            hs.config.federation_ip_range_blacklist,
         )
 
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eff0975b6a..0e845212a9 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -100,7 +100,7 @@ class HttpPusher:
         if "url" not in self.data:
             raise PusherConfigException("'url' required in data for HTTP pusher")
         self.url = self.data["url"]
-        self.http_client = hs.get_proxied_http_client()
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self.data_minus_url = {}
         self.data_minus_url.update(self.data)
         del self.data_minus_url["url"]
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..90940ff185 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet):
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
             raise SynapseError(400, "Invalid user type")
 
+        if "mac" not in body:
+            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
         got_mac = body["mac"]
 
         want_mac_builder = hmac.new(
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..83beb02b05 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -66,7 +66,7 @@ class MediaRepository:
     def __init__(self, hs):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
diff --git a/synapse/server.py b/synapse/server.py
index b017e3489f..9af759626e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -350,17 +350,46 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_simple_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client with no special configuration.
+        """
         return SimpleHttpClient(self)
 
     @cache_in_self
     def get_proxied_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client that uses configured HTTP(S) proxies.
+        """
+        return SimpleHttpClient(
+            self,
+            http_proxy=os.getenvb(b"http_proxy"),
+            https_proxy=os.getenvb(b"HTTPS_PROXY"),
+        )
+
+    @cache_in_self
+    def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
+        based on the IP range blacklist.
+        """
         return SimpleHttpClient(
             self,
+            ip_blacklist=self.config.ip_range_blacklist,
             http_proxy=os.getenvb(b"http_proxy"),
             https_proxy=os.getenvb(b"HTTPS_PROXY"),
         )
 
     @cache_in_self
+    def get_federation_http_client(self) -> MatrixFederationHttpClient:
+        """
+        An HTTP client for federation.
+        """
+        tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
+            self.config
+        )
+        return MatrixFederationHttpClient(self, tls_client_options_factory)
+
+    @cache_in_self
     def get_room_creation_handler(self) -> RoomCreationHandler:
         return RoomCreationHandler(self)
 
@@ -515,13 +544,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PusherPool(self)
 
     @cache_in_self
-    def get_http_client(self) -> MatrixFederationHttpClient:
-        tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
-            self.config
-        )
-        return MatrixFederationHttpClient(self, tls_client_options_factory)
-
-    @cache_in_self
     def get_media_repository_resource(self) -> MediaRepositoryResource:
         # build the media repo resource. This indirects through the HomeServer
         # to ensure that we only have a single instance of
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..ffc504ce77 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -252,9 +252,88 @@ async def _get_auth_chain_difference(
         Set of event IDs
     """
 
-    difference = await state_res_store.get_auth_chain_difference(
-        [set(state_set.values()) for state_set in state_sets]
-    )
+    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+    # all events passed to it (and their auth chains) have been persisted
+    # previously. This is not the case for any events in the `event_map`, and so
+    # we need to manually handle those events.
+    #
+    # We do this by:
+    #   1. calculating the auth chain difference for the state sets based on the
+    #      events in `event_map` alone
+    #   2. replacing any events in the state_sets that are also in `event_map`
+    #      with their auth events (recursively), and then calling
+    #      `store.get_auth_chain_difference` as normal
+    #   3. adding the results of 1 and 2 together.
+
+    # Map from event ID in `event_map` to their auth event IDs, and their auth
+    # event IDs if they appear in the `event_map`. This is the intersection of
+    # the event's auth chain with the events in the `event_map` *plus* their
+    # auth event IDs.
+    events_to_auth_chain = {}  # type: Dict[str, Set[str]]
+    for event in event_map.values():
+        chain = {event.event_id}
+        events_to_auth_chain[event.event_id] = chain
+
+        to_search = [event]
+        while to_search:
+            for auth_id in to_search.pop().auth_event_ids():
+                chain.add(auth_id)
+                auth_event = event_map.get(auth_id)
+                if auth_event:
+                    to_search.append(auth_event)
+
+    # We now a) calculate the auth chain difference for the unpersisted events
+    # and b) work out the state sets to pass to the store.
+    #
+    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # much simpler calculation.
+    if event_map:
+        # The list of state sets to pass to the store, where each state set is a set
+        # of the event ids making up the state. This is similar to `state_sets`,
+        # except that (a) we only have event ids, not the complete
+        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+        # unpersisted events and replaced them with the persisted events in
+        # their auth chain.
+        state_sets_ids = []  # type: List[Set[str]]
+
+        # For each state set, the unpersisted event IDs reachable (by their auth
+        # chain) from the events in that set.
+        unpersisted_set_ids = []  # type: List[Set[str]]
+
+        for state_set in state_sets:
+            set_ids = set()  # type: Set[str]
+            state_sets_ids.append(set_ids)
+
+            unpersisted_ids = set()  # type: Set[str]
+            unpersisted_set_ids.append(unpersisted_ids)
+
+            for event_id in state_set.values():
+                event_chain = events_to_auth_chain.get(event_id)
+                if event_chain is not None:
+                    # We have an event in `event_map`. We add all the auth
+                    # events that it references (that aren't also in `event_map`).
+                    set_ids.update(e for e in event_chain if e not in event_map)
+
+                    # We also add the full chain of unpersisted event IDs
+                    # referenced by this state set, so that we can work out the
+                    # auth chain difference of the unpersisted events.
+                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                else:
+                    set_ids.add(event_id)
+
+        # The auth chain difference of the unpersisted events of the state sets
+        # is calculated by taking the difference between the union and
+        # intersections.
+        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+        difference_from_event_map = union - intersection  # type: Collection[str]
+    else:
+        difference_from_event_map = ()
+        state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
+    difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
+    difference.update(difference_from_event_map)
 
     return difference