summary refs log tree commit diff
path: root/synapse/http/replicationagent.py
blob: 800f21873d6673f57f131b3e9111cfdd15de8c5b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Optional

from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.python.failure import Failure
from twisted.web.client import URI, HTTPConnectionPool, _AgentBase
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import (
    IAgent,
    IAgentEndpointFactory,
    IBodyProducer,
    IPolicyForHTTPS,
    IResponse,
)

from synapse.config.workers import InstanceLocationConfig
from synapse.types import ISynapseReactor

logger = logging.getLogger(__name__)


@implementer(IAgentEndpointFactory)
class ReplicationEndpointFactory:
    """Build client endpoints for replication requests addressed by worker name."""

    def __init__(
        self,
        reactor: ISynapseReactor,
        instance_map: Dict[str, InstanceLocationConfig],
        context_factory: IPolicyForHTTPS,
    ) -> None:
        """
        Args:
            reactor: Reactor handed to every endpoint this factory creates.
            instance_map: Connection details (scheme, host, port) for each
                worker, keyed by worker name.
            context_factory: Supplies the TLS connection creator used when a
                worker's scheme is "https".
        """
        self.reactor = reactor
        self.instance_map = instance_map
        self.context_factory = context_factory

    def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
        """
        Pick the client endpoint to use for the worker named in the URI.

        Args:
            uri: The pre-parsed URI. Its netloc is read as a worker name, not
                as a network address.

        Returns:
            A plain TCP endpoint for "http" workers, or that endpoint wrapped
            in client TLS for "https" workers.

        Raises:
            SchemeNotSupported: if the worker's configured scheme is neither
                "http" nor "https".
        """
        # The URI only carries the worker name; where to actually connect is
        # taken from that worker's entry in the instance map.
        location = self.instance_map[uri.netloc.decode("utf-8")]
        scheme = location.scheme()

        if scheme not in ("http", "https"):
            raise SchemeNotSupported(f"Unsupported scheme: {scheme}")

        tcp_endpoint = HostnameEndpoint(self.reactor, location.host, location.port)
        if scheme != "https":
            return tcp_endpoint

        # The 'port' argument isn't actually used by creatorForNetloc.
        tls_creator = self.context_factory.creatorForNetloc(
            location.host, location.port
        )
        return wrapClientTLS(tls_creator, tcp_endpoint)


@implementer(IAgent)
class ReplicationAgent(_AgentBase):
    """
    HTTP and HTTPS client agent for talking to replication endpoints.

    Large parts of this are adapted from Twisted's twisted.web.client.Agent.
    """

    def __init__(
        self,
        reactor: ISynapseReactor,
        instance_map: Dict[str, InstanceLocationConfig],
        contextFactory: IPolicyForHTTPS,
        connectTimeout: Optional[float] = None,
        bindAddress: Optional[bytes] = None,
        pool: Optional[HTTPConnectionPool] = None,
    ):
        """
        Create a ReplicationAgent.

        Args:
            reactor: A reactor for this Agent to place outgoing connections.
            instance_map: Per-worker connection details, handed to the
                ReplicationEndpointFactory that resolves request URIs.
            contextFactory: A factory for TLS contexts, handed to the
                ReplicationEndpointFactory for use with HTTPS connections.
            connectTimeout: The amount of time that this Agent will wait
                for the peer to accept a connection.
                NOTE(review): not referenced anywhere in this class, so it
                currently has no effect - confirm whether that is intended.
            bindAddress: The local address for client sockets to bind to.
                NOTE(review): not referenced anywhere in this class, so it
                currently has no effect - confirm whether that is intended.
            pool: An HTTPConnectionPool instance, or None, in which
                case a non-persistent HTTPConnectionPool instance will be
                created.
        """
        super().__init__(reactor, pool)
        self._endpointFactory = ReplicationEndpointFactory(
            reactor, instance_map, contextFactory
        )

    def request(
        self,
        method: bytes,
        uri: bytes,
        headers: Optional[Headers] = None,
        bodyProducer: Optional[IBodyProducer] = None,
    ) -> "defer.Deferred[IResponse]":
        """
        Issue a request to the server indicated by the given uri.

        The connection may be a pooled one that already exists, or a freshly
        created one.

        The endpoint factory only produces endpoints for HTTP and HTTPS; if
        it raises SchemeNotSupported, that error is delivered through the
        returned Deferred rather than raised from this method.

        This is copied from twisted.web.client.Agent, except:

        * It uses a different pool key: the (scheme, netloc) pair of the
          parsed uri.
        * It does not call _ensureValidURI(...) since it breaks on some
          UNIX paths.

        See: twisted.web.iweb.IAgent.request
        """
        parsed_uri = URI.fromBytes(uri)

        try:
            endpoint = self._endpointFactory.endpointForURI(parsed_uri)
        except SchemeNotSupported:
            # Hand the in-flight exception back as an already-failed Deferred.
            return defer.fail(Failure())

        # Connections are pooled per (scheme, netloc) of the request URI. The
        # endpoint factory reads that netloc as a worker name.
        pool_key = (parsed_uri.scheme, parsed_uri.netloc)
        request_path = parsed_uri.originForm

        # _requestWithEndpoint is inherited from _AgentBase.
        return self._requestWithEndpoint(
            pool_key,
            endpoint,
            method,
            parsed_uri,
            headers,
            bodyProducer,
            request_path,
        )