1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
|
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
from typing import Dict, Optional
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import (
HostnameEndpoint,
UNIXClientEndpoint,
wrapClientTLS,
)
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.python.failure import Failure
from twisted.web.client import URI, HTTPConnectionPool, _AgentBase
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import (
IAgent,
IAgentEndpointFactory,
IBodyProducer,
IPolicyForHTTPS,
IResponse,
)
from synapse.config.workers import (
InstanceLocationConfig,
InstanceTcpLocationConfig,
InstanceUnixLocationConfig,
)
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@implementer(IAgentEndpointFactory)
class ReplicationEndpointFactory:
    """Build client endpoints (TCP, TLS-over-TCP or UNIX socket) for workers.

    The connection details for each worker are looked up by name in the
    instance map supplied at construction time.
    """

    def __init__(
        self,
        reactor: ISynapseReactor,
        instance_map: Dict[str, InstanceLocationConfig],
        context_factory: IPolicyForHTTPS,
    ) -> None:
        """
        Args:
            reactor: The reactor handed to every endpoint this factory creates.
            instance_map: Maps a worker name to its location config.
            context_factory: Supplies the TLS connection creator used when a
                TCP location reports the "https" scheme.
        """
        self.reactor = reactor
        self.instance_map = instance_map
        self.context_factory = context_factory

    def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
        """
        Pick the kind of endpoint to connect to for the given URI.

        The URI uses a special scheme whose netloc is the worker name; the real
        connection details come from the instance map entry for that name.

        Args:
            uri: The pre-parsed URI object containing all the uri data

        Returns: The correct client endpoint object

        Raises:
            KeyError: if the worker name is not a key of the instance map.
            SchemeNotSupported: if the location config is neither a TCP nor a
                UNIX socket config.
        """
        target = self.instance_map[uri.netloc.decode("utf-8")]
        scheme = target.scheme()

        if isinstance(target, InstanceTcpLocationConfig):
            plain_endpoint = HostnameEndpoint(
                self.reactor,
                target.host,
                target.port,
            )
            if scheme != "https":
                return plain_endpoint

            # The 'port' argument below isn't actually used by the function
            tls_creator = self.context_factory.creatorForNetloc(
                target.host.encode("utf-8"),
                target.port,
            )
            return wrapClientTLS(tls_creator, plain_endpoint)

        if isinstance(target, InstanceUnixLocationConfig):
            return UNIXClientEndpoint(self.reactor, target.path)

        raise SchemeNotSupported(f"Unsupported scheme: {scheme}")
@implementer(IAgent)
class ReplicationAgent(_AgentBase):
    """
    HTTP/HTTPS client used to talk to replication endpoints.

    A large part of this class mirrors Twisted's twisted.web.client.Agent.
    """

    def __init__(
        self,
        reactor: ISynapseReactor,
        instance_map: Dict[str, InstanceLocationConfig],
        contextFactory: IPolicyForHTTPS,
        connectTimeout: Optional[float] = None,
        bindAddress: Optional[bytes] = None,
        pool: Optional[HTTPConnectionPool] = None,
    ):
        """
        Create a ReplicationAgent.

        Args:
            reactor: A reactor for this Agent to place outgoing connections.
            instance_map: Maps a worker name to its location config; passed on
                to the ReplicationEndpointFactory built here.
            contextFactory: A factory for TLS contexts, to control the
                verification parameters of OpenSSL.
            connectTimeout: The amount of time that this Agent will wait
                for the peer to accept a connection. (Not referenced by this
                constructor's body.)
            bindAddress: The local address for client sockets to bind to.
                (Not referenced by this constructor's body.)
            pool: An HTTPConnectionPool instance, or None, in which
                case a non-persistent HTTPConnectionPool instance will be
                created.
        """
        _AgentBase.__init__(self, reactor, pool)
        self._endpointFactory = ReplicationEndpointFactory(
            reactor, instance_map, contextFactory
        )

    def request(
        self,
        method: bytes,
        uri: bytes,
        headers: Optional[Headers] = None,
        bodyProducer: Optional[IBodyProducer] = None,
    ) -> "defer.Deferred[IResponse]":
        """
        Issue a request to the server indicated by the given uri.

        Either a pooled connection is reused or a fresh one is opened.

        Currently, HTTP, HTTPS and UNIX schemes are supported in uri.

        This is copied from twisted.web.client.Agent, except:

        * It uses a different pool key (combining the scheme with either host & port or
          socket path).
        * It does not call _ensureValidURI(...) as the strictness of IDNA2008 is not
          required when using a worker's name as a 'hostname' for Synapse HTTP
          Replication machinery. Specifically, this allows a range of ascii characters
          such as '+' and '_' in hostnames/worker's names.

        See: twisted.web.iweb.IAgent.request
        """
        target_uri = URI.fromBytes(uri)

        try:
            endpoint = self._endpointFactory.endpointForURI(target_uri)
        except SchemeNotSupported:
            # Failure() with no arguments captures the exception being handled,
            # so it must be created inside this except block.
            return defer.fail(Failure())

        # The URI's netloc is the worker name; its location config provides
        # the Pool key:
        #   (http(s), <host:port>) or (unix, <socket_path>)
        location = self._endpointFactory.instance_map[
            target_uri.netloc.decode("utf-8")
        ]
        pool_key = (location.scheme(), location.netloc())

        # _requestWithEndpoint comes from _AgentBase class
        return self._requestWithEndpoint(
            pool_key,
            endpoint,
            method,
            target_uri,
            headers,
            bodyProducer,
            target_uri.originForm,
        )
|