diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8997c1f9e7..8d28076d92 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -565,7 +565,7 @@ def run(hs):
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
)
try:
- yield hs.get_simple_http_client().put_json(
+ yield hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 2d2c1e54df..75bb904718 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -234,8 +234,8 @@ def setup_logging(
# make sure that the first thing we log is a thing we can grep backwards
# for
- logging.warn("***** STARTING SERVER *****")
- logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
+ logging.warning("***** STARTING SERVER *****")
+ logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
return logger
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 595706d01a..545d719652 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -555,7 +555,7 @@ class FederationClient(FederationBase):
Note that this does not append any events to any graphs.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations (Iterable[str]): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index b754a09d7a..a5b36b1827 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -360,20 +360,20 @@ class PerDestinationQueue(object):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
- now_stream_id, results = yield self._store.get_devices_by_remote(
+ now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
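+ # get_device_updates_by_remote yields (edu_type, content) pairs, since some
+ # updates are cross-signing key updates rather than plain device list updates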
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type=edu_type,
content=content,
)
- for content in results
+ for (edu_type, content) in results
]
- assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+ assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
return (edus, now_stream_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 16b4617f68..26ef5e150c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ user_id, "self_signing"
+ )
+
+ return {
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ "master_key": master_key,
+ "self_signing_key": self_signing_key,
+ }
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 526379c6f7..c4632f8984 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
ignore_backoff=True,
)
except CodeMessageException as e:
- logging.warn("Error retrieving alias")
+ logging.warning("Error retrieving alias")
if e.code == 404:
result = None
else:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 0449034a4e..f09a0b73c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- hs.get_federation_registry().register_query_handler(
+ federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -208,13 +219,15 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
@@ -252,7 +265,7 @@ class E2eKeysHandler(object):
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master|self_signing|user_signing) -> user_id -> key
+ (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -344,7 +357,16 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- return {"device_keys": res}
+ ret = {"device_keys": res}
+
+ # add in the cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, None
+ )
+
+ ret.update(cross_signing_keys)
+
+ return ret
@trace
@defer.inlineCallbacks
@@ -1058,3 +1080,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+ """Handles incoming signing key updates from federation and updates the DB"""
+
+ def __init__(self, hs, e2e_keys_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.e2e_keys_handler = e2e_keys_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have around to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="signing_key_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_signing_key_update(self, origin, edu_content):
+ """Called on incoming signing key update from federation. Responsible for
+ parsing the EDU and adding it to the pending updates list.
+
+ Args:
+ origin (str): the server that sent the EDU
+ edu_content (dict): the contents of the EDU
+ """
+
+ user_id = edu_content.pop("user_id")
+ master_key = edu_content.pop("master_key", None)
+ self_signing_key = edu_content.pop("self_signing_key", None)
+
+ if get_domain_from_id(user_id) != origin:
+ logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+ return
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We don't share any rooms with this user. Ignore update, as we
+ # probably won't get any further updates.
+ return
+
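+ # batch the update with any other pending updates for this user; they are
+ # all processed under the linearizer in _handle_signing_key_updates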
+ self._pending_updates.setdefault(user_id, []).append(
+ (master_key, self_signing_key)
+ )
+
+ yield self._handle_signing_key_updates(user_id)
+
+ @defer.inlineCallbacks
+ def _handle_signing_key_updates(self, user_id):
+ """Actually handle pending updates.
+
+ Args:
+ user_id (str): the user whose updates we are processing
+ """
+
+ device_handler = self.e2e_keys_handler.device_handler
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ device_ids = []
+
+ logger.info("pending updates: %r", pending_updates)
+
+ for master_key, self_signing_key in pending_updates:
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "master", master_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(
+ self_signing_key
+ )
+ device_ids.append(verify_key.version)
+
+ yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a932d3085f..dab6be9573 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1106,7 +1106,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the
- server `target_host`.
+ servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
event that we can fill out and sign. This is then sent to the
@@ -1115,6 +1115,15 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we
have finished processing the join.
+
+ Args:
+ target_hosts (Iterable[str]): List of servers to attempt to join the room with.
+
+ room_id (str): The ID of the room to join.
+
+ joinee (str): The User ID of the joining user.
+
+ content (dict): The event content to use for the join event.
"""
logger.debug("Joining %s to %s", joinee, room_id)
@@ -1174,6 +1183,22 @@ class FederationHandler(BaseHandler):
yield self._persist_auth_tree(origin, auth_chain, state, event)
+ # Check whether this room is the result of an upgrade of a room we already know
+ # about. If so, migrate over user information
+ predecessor = yield self.store.get_room_predecessor(room_id)
+ if not predecessor:
+ return
+ old_room_id = predecessor["room_id"]
+ logger.debug(
+ "Found predecessor for %s during remote join: %s", room_id, old_room_id
+ )
+
+ # We retrieve the room member handler here so as not to cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ yield member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, room_id
+ )
+
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -2442,6 +2467,8 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
+
+ # We retrieve the room member handler here so as not to cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
@@ -2502,6 +2529,7 @@ class FederationHandler(BaseHandler):
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+ # We retrieve the room member handler here so as not to cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 0182e5b432..e92b2eafd5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id,
new_version, # args for _upgrade_room
)
+
return ret
@defer.inlineCallbacks
@@ -189,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
requester, old_room_id, new_room_id, old_room_state
)
- # and finally, shut down the PLs in the old room, and update them in the new
+ # Copy over user push rules and tags, and migrate the room directory state
+ yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, new_room_id
+ )
+
+ # finally, shut down the power levels in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9a940d2c05..06d09c2947 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,10 +203,6 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- # Copy over user state if we're joining an upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -455,11 +451,6 @@ class RoomMemberHandler(object):
requester, remote_room_hosts, room_id, target, content
)
- # Copy over user state if this is a join on an remote upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
-
return remote_join_response
elif effective_membership_state == Membership.LEAVE:
@@ -498,36 +489,72 @@ class RoomMemberHandler(object):
return res
@defer.inlineCallbacks
- def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
- """Copy user-specific information when they join a new room if that new room is the
+ def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ """Upon our server becoming aware of an upgraded room, either by upgrading a room
+ ourselves or joining one, we can transfer over information from the previous room.
+
+ Copies user state (tags/push rules) for every local user that was in the old
+ room, and migrates the room directory state.
+
+ Args:
+ old_room_id (str): The ID of the old room
+
+ room_id (str): The ID of the new room
+
+ Returns:
+ Deferred
+ """
+ # Find all the users that were in the old room and copy over each user's state
+ users = yield self.store.get_users_in_room(old_room_id)
+ yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+
+ # If the old room was published in the room directory, unpublish it and
+ # publish the new room instead
+ old_room = yield self.store.get_room(old_room_id)
+ if old_room and old_room["is_public"]:
+ yield self.store.set_room_is_public(old_room_id, False)
+ yield self.store.set_room_is_public(room_id, True)
+
+ @defer.inlineCallbacks
+ def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ """Copy user-specific information when they join a new room when that new room is the
result of a room upgrade
Args:
- new_room_id (str): The ID of the room the user is joining
- user_id (str): The ID of the user
+ old_room_id (str): The ID of the old room
+ new_room_id (str): The ID of the new room
+ user_ids (Iterable[str]): User IDs to copy state for
Returns:
Deferred
"""
- # Check if the new room is an upgraded room
- predecessor = yield self.store.get_room_predecessor(new_room_id)
- if not predecessor:
- return
logger.debug(
- "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+ "Copying over room tags and push rules from %s to %s for users %s",
+ old_room_id,
new_room_id,
- predecessor,
+ user_ids,
)
- # It is an upgraded room. Copy over old tags
- yield self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], new_room_id, user_id
- )
- # Copy over push rules
- yield self.store.copy_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], new_room_id, user_id
- )
+ for user_id in user_ids:
+ try:
+ # Copy over old room tags
+ yield self.copy_room_tags_and_direct_to_room(
+ old_room_id, new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+ except Exception:
+ logger.exception(
+ "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+ "Skipping...",
+ old_room_id,
+ new_room_id,
+ user_id,
+ )
+ continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 29aa1e5aaf..8363d887a9 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
super().__init__(hs)
self._enabled = bool(hs.config.recaptcha_private_key)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 2df5b383b5..d4c285445e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
@@ -183,7 +184,15 @@ class SimpleHttpClient(object):
using HTTP in Matrix
"""
- def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ hs,
+ treq_args={},
+ ip_whitelist=None,
+ ip_blacklist=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
"""
Args:
hs (synapse.server.HomeServer)
@@ -192,6 +201,8 @@ class SimpleHttpClient(object):
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -236,11 +247,13 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(
+ self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
+ http_proxy=http_proxy,
+ https_proxy=https_proxy,
)
if self._ip_blacklist:
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 0000000000..be7b2ceb8e
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+ pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+ """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+ Wraps an existing HostnameEndpoint for the proxy.
+
+ When we get the connect() request from the connection pool (via the TLS wrapper),
+ we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+ CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+ in.
+
+ Args:
+ reactor: the Twisted reactor to use for the connection
+ proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+ proxy
+ host (bytes): hostname that we want to CONNECT to
+ port (int): port that we want to connect to
+ """
+
+ def __init__(self, reactor, proxy_endpoint, host, port):
+ self._reactor = reactor
+ self._proxy_endpoint = proxy_endpoint
+ self._host = host
+ self._port = port
+
+ def __repr__(self):
+ return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
+
+ def connect(self, protocolFactory):
+ f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ d = self._proxy_endpoint.connect(f)
+ # once the tcp socket connects successfully, we need to wait for the
+ # CONNECT to complete.
+ d.addCallback(lambda conn: f.on_connection)
+ return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+ """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+ Once the CONNECT completes, invokes the original ClientFactory to build the
+ HTTP Protocol object and run the rest of the connection.
+
+ Args:
+ dst_host (bytes): hostname that we want to CONNECT to
+ dst_port (int): port that we want to connect to
+ wrapped_factory (protocol.ClientFactory): The original Factory
+ """
+
+ def __init__(self, dst_host, dst_port, wrapped_factory):
+ self.dst_host = dst_host
+ self.dst_port = dst_port
+ self.wrapped_factory = wrapped_factory
+ self.on_connection = defer.Deferred()
+
+ def startedConnecting(self, connector):
+ return self.wrapped_factory.startedConnecting(connector)
+
+ def buildProtocol(self, addr):
+ wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+ return HTTPConnectProtocol(
+ self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ )
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.debug("Connection to proxy failed: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.debug("Connection to proxy lost: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+ """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+ Args:
+ host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ to put in the CONNECT request
+
+ port (int): The original HTTP(s) port to put in the CONNECT request
+
+ wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+ HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+ connected_deferred (Deferred): a Deferred which will be callbacked with
+ wrapped_protocol when the CONNECT completes
+ """
+
+ def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ self.host = host
+ self.port = port
+ self.wrapped_protocol = wrapped_protocol
+ self.connected_deferred = connected_deferred
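+ # the CONNECT handshake itself is driven by an HTTPConnectSetupClient; once
+ # it reports success, we hand the transport over to the wrapped protocol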
+ self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+ def connectionMade(self):
+ self.http_setup_client.makeConnection(self.transport)
+
+ def connectionLost(self, reason=connectionDone):
+ if self.wrapped_protocol.connected:
+ self.wrapped_protocol.connectionLost(reason)
+
+ self.http_setup_client.connectionLost(reason)
+
+ if not self.connected_deferred.called:
+ self.connected_deferred.errback(reason)
+
+ def proxyConnected(self, _):
+ self.wrapped_protocol.makeConnection(self.transport)
+
+ self.connected_deferred.callback(self.wrapped_protocol)
+
+ # Get any pending data from the http buf and forward it to the original protocol
+ buf = self.http_setup_client.clearLineBuffer()
+ if buf:
+ self.wrapped_protocol.dataReceived(buf)
+
+ def dataReceived(self, data):
+ # if we've set up the HTTP protocol, we can send the data there
+ if self.wrapped_protocol.connected:
+ return self.wrapped_protocol.dataReceived(data)
+
+ # otherwise, we must still be setting up the connection: send the data to the
+ # setup client
+ return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+ """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+ Args:
+ host (bytes): The hostname to send in the CONNECT message
+ port (int): The port to send in the CONNECT message
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ self.on_connected = defer.Deferred()
+
+ def connectionMade(self):
+ logger.debug("Connected to proxy, sending CONNECT")
+ self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+ self.endHeaders()
+
+ def handleStatus(self, version, status, message):
+ logger.debug("Got Status: %s %s %s", status, message, version)
+ if status != b"200":
+ raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+ def handleEndHeaders(self):
+ logger.debug("End Headers")
+ self.on_connected.callback(None)
+
+ def handleResponse(self, body):
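+ # the response to a CONNECT request is not expected to include a body, so
+ # there is nothing to do here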
+ pass
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 0000000000..332da02a8d
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+ """An Agent implementation which will use an HTTP proxy if one was requested
+
+ Args:
+ reactor: twisted reactor to place outgoing
+ connections on.
+
+ contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ verification parameters of OpenSSL. The default is to use a
+ `BrowserLikePolicyForHTTPS`, so unless you have special
+ requirements you can leave this as-is.
+
+ connectTimeout (float): The amount of time that this Agent will wait
+ for the peer to accept a connection.
+
+ bindAddress (bytes): The local address for client sockets to bind to.
+
+ pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ non-persistent pool instance will be created.
+
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
+ """
+
+ def __init__(
+ self,
+ reactor,
+ contextFactory=BrowserLikePolicyForHTTPS(),
+ connectTimeout=None,
+ bindAddress=None,
+ pool=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
+ _AgentBase.__init__(self, reactor, pool)
+
+ self._endpoint_kwargs = {}
+ if connectTimeout is not None:
+ self._endpoint_kwargs["timeout"] = connectTimeout
+ if bindAddress is not None:
+ self._endpoint_kwargs["bindAddress"] = bindAddress
+
+ self.http_proxy_endpoint = _http_proxy_endpoint(
+ http_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self.https_proxy_endpoint = _http_proxy_endpoint(
+ https_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self._policy_for_https = contextFactory
+ self._reactor = reactor
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Issue a request to the server indicated by the given uri.
+
+ Supports `http` and `https` schemes.
+
+ An existing connection from the connection pool may be used or a new one may be
+ created.
+
+ See also: twisted.web.iweb.IAgent.request
+
+ Args:
+ method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+ uri (bytes): The location of the resource to request.
+
+ headers (Headers|None): Extra headers to send with the request
+
+ bodyProducer (IBodyProducer|None): An object which can generate bytes to
+ make up the body of this request (for example, the properly encoded
+ contents of a file for a file upload). Or, None if the request is to
+ have no body.
+
+ Returns:
+ Deferred[IResponse]: completes when the header of the response has
+ been received (regardless of the response status code).
+ """
+ uri = uri.strip()
+ if not _VALID_URI.match(uri):
+ raise ValueError("Invalid URI {!r}".format(uri))
+
+ parsed_uri = URI.fromBytes(uri)
+ pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ request_path = parsed_uri.originForm
+
+ if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ pool_key = ("http-proxy", self.http_proxy_endpoint)
+ endpoint = self.http_proxy_endpoint
+ request_path = uri
+ elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self.https_proxy_endpoint,
+ parsed_uri.host,
+ parsed_uri.port,
+ )
+ else:
+ # not using a proxy
+ endpoint = HostnameEndpoint(
+ self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+ )
+
+ logger.debug("Requesting %s via %s", uri, endpoint)
+
+ if parsed_uri.scheme == b"https":
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ parsed_uri.host, parsed_uri.port
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+ elif parsed_uri.scheme == b"http":
+ pass
+ else:
+ return defer.fail(
+ Failure(
+ SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+ )
+ )
+
+ return self._requestWithEndpoint(
+ pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+ )
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+ """Parses an http proxy setting and returns an endpoint for the proxy
+
+ Args:
+ proxy (bytes|None): the proxy setting
+ reactor: reactor to be used to connect to the proxy
+ kwargs: other args to be passed to HostnameEndpoint
+
+ Returns:
+ interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+ or None
+ """
+ if proxy is None:
+ return None
+
+ # currently we only support hostname:port. Some apps also support
+ # protocol://<host>[:port], which provides a way to require a TLS connection to
+ # the proxy.
+
+ host, port = parse_host_port(proxy, default_port=1080)
+ return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
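+ """Parse a host[:port] string into a (host, port) tuple.
+
+ Args:
+ hostport (bytes): the string to parse
+ default_port (int|None): the port to return if `hostport` does not
+ include one
+
+ Returns:
+ tuple[bytes, int|None]: the parsed host and port
+ """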
+ # could have sworn we had one of these somewhere else...
+ if b":" in hostport:
+ host, port = hostport.rsplit(b":", 1)
+ try:
+ port = int(port)
+ return host, port
+ except ValueError:
+ # the thing after the : wasn't a valid port; presumably this is an
+ # IPv6 address.
+ pass
+
+ return hostport, default_port
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 7dde2ad055..e994037be6 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -103,7 +103,7 @@ class HttpPusher(object):
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 00a7dd6d09..24a0ce74f2 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet):
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5a25b6b3fc..531d923f76 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
diff --git a/synapse/server.py b/synapse/server.py
index 0b81af646c..f8aeebcff8 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
# Imports required for the default HomeServer() implementation
import abc
import logging
+import os
from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
@@ -168,6 +169,7 @@ class HomeServer(object):
"filtering",
"http_client_context_factory",
"simple_http_client",
+ "proxied_http_client",
"media_repository",
"media_repository_resource",
"federation_transport_client",
@@ -311,6 +313,13 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
+ def build_proxied_http_client(self):
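+ # an http client which routes its requests via any http_proxy/HTTPS_PROXY
+ # set in the environment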
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
+ )
+
def build_room_creation_handler(self):
return RoomCreationHandler(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 83d1f11283..b5e0b57095 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
+import synapse.http.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -38,6 +39,14 @@ class HomeServer(object):
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
+ def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ or support any HTTP_PROXY settings"""
+ pass
+ def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ but does support HTTP_PROXY settings"""
+ pass
def get_deactivate_account_handler(
self,
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..71f62036c0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[str, dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = set(r[0] for r in updates)
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
-
- update_context = update[3]
- update_stream_id = update[2]
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
- previous_update_stream_id, _ = query_map.get(key, (0, None))
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
- if update_stream_id > previous_update_stream_id:
- query_map[key] = (update_stream_id, update_context)
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
+ if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
+
return now_stream_id, results
- def _get_devices_by_remote_txn(
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing a device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
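+ # query_map may be empty when the only updates are cross-signing key
+ # updates, in which case there are no device keys to look up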
+ devices = (
+ yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
return results
|