diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a0b7cb7963..da2f5e8cfd 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -31,6 +31,9 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
+ def __init__(self, hs):
+ pass
+
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 37ee469fa2..9ba3151713 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -24,6 +24,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
@@ -50,7 +51,33 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
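+# How long (in ms) to wait before retrying a destination that recently failed
+# to return a requested PDU.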
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
class FederationClient(FederationBase):
+ def __init__(self, hs):
+ super(FederationClient, self).__init__(hs)
+
+ self.pdu_destination_tried = {}
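+ # Periodically prune attempts older than PDU_RETRY_TIME_MS so this
+ # cache doesn't grow without bound.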
+ self._clock.looping_call(
+ self._clear_tried_cache, 60 * 1000,
+ )
+
+ def _clear_tried_cache(self):
+ """Clear pdu_destination_tried cache"""
+ now = self._clock.time_msec()
+
+ old_dict = self.pdu_destination_tried
+ self.pdu_destination_tried = {}
+
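+ # Keep only the attempts that are still within the retry window.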
+ for event_id, destination_dict in old_dict.items():
+ destination_dict = {
+ dest: time
+ for dest, time in destination_dict.items()
+ if time + PDU_RETRY_TIME_MS > now
+ }
+ if destination_dict:
+ self.pdu_destination_tried[event_id] = destination_dict
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
@@ -233,12 +260,19 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache:
- e = self._get_pdu_cache.get(event_id)
- if e:
- defer.returnValue(e)
+ ev = self._get_pdu_cache.get(event_id)
+ if ev:
+ defer.returnValue(ev)
+
+ pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None
for destination in destinations:
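+ # If we recently failed to get this event from this destination,
+ # skip it for now rather than hitting the same server again.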
+ now = self._clock.time_msec()
+ last_attempt = pdu_attempts.get(destination, 0)
+ if last_attempt + PDU_RETRY_TIME_MS > now:
+ continue
+
try:
limiter = yield get_retry_limiter(
destination,
@@ -266,25 +300,19 @@ class FederationClient(FederationBase):
break
- except SynapseError:
- logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
- )
- continue
- except CodeMessageException as e:
- if 400 <= e.code < 500:
- raise
+ pdu_attempts[destination] = now
+ except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
- continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
+ pdu_attempts[destination] = now
+
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
@@ -311,6 +339,42 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs.
"""
+ try:
+ # First we try and ask for just the IDs, as that's far quicker if
+ # we have most of the state and auth_chain already.
+ # However, this may 404 if the other side has an old synapse.
+ result = yield self.transport_layer.get_room_state_ids(
+ destination, room_id, event_id=event_id,
+ )
+
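+ # The response contains only event IDs, e.g.
+ # {"pdu_ids": [...], "auth_chain_ids": [...]} (see on_state_ids_request
+ # in federation_server.py).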
+ state_event_ids = result["pdu_ids"]
+ auth_event_ids = result.get("auth_chain_ids", [])
+
+ fetched_events, failed_to_fetch = yield self.get_events(
+ [destination], room_id, set(state_event_ids + auth_event_ids)
+ )
+
+ if failed_to_fetch:
+ logger.warn("Failed to get %r", failed_to_fetch)
+
+ event_map = {
+ ev.event_id: ev for ev in fetched_events
+ }
+
+ pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+ auth_chain = [
+ event_map[e_id] for e_id in auth_event_ids if e_id in event_map
+ ]
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue((pdus, auth_chain))
+ except HttpResponseException as e:
+ if e.code == 400 or e.code == 404:
+ logger.info("Failed to use get_room_state_ids API, falling back")
+ else:
+ raise e
+
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
@@ -324,12 +388,26 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
+ seen_events = yield self.store.get_events([
+ ev.event_id for ev in itertools.chain(pdus, auth_chain)
+ ])
+
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus, outlier=True
+ destination,
+ [p for p in pdus if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_pdus.extend(
+ seen_events[p.event_id] for p in pdus if p.event_id in seen_events
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
+ destination,
+ [p for p in auth_chain if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_auth.extend(
+ seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
)
signed_auth.sort(key=lambda e: e.depth)
@@ -337,6 +415,67 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
+ def get_events(self, destinations, room_id, event_ids, return_local=True):
+ """Fetch events from some remote destinations, checking if we already
+ have them.
+
+ Args:
+ destinations (list): A list of server names to try fetching from.
+ room_id (str): The room the events belong to.
+ event_ids (list): The IDs of the events to fetch.
+ return_local (bool): Whether to include events we already have in
+ the DB in the returned list of events
+
+ Returns:
+ Deferred: A deferred resolving to a 2-tuple where the first is a list of
+ events and the second is a list of event ids that we failed to fetch.
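+
+ Example (illustrative only)::
+
+ events, failed = yield self.get_events(
+ ["remote.example.com"], room_id, ["$someevent:remote.example.com"],
+ )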
+ """
+ if return_local:
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
+ signed_events = seen_events.values()
+ else:
+ seen_events = yield self.store.have_events(event_ids)
+ signed_events = []
+
+ failed_to_fetch = set()
+
+ missing_events = set(event_ids)
+ for k in seen_events:
+ missing_events.discard(k)
+
+ if not missing_events:
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ def random_server_list():
+ srvs = list(destinations)
+ random.shuffle(srvs)
+ return srvs
+
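+ # Fetch the missing events in batches, trying the destinations in a
+ # random order for each event to spread the load.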
+ batch_size = 20
+ missing_events = list(missing_events)
+ for i in xrange(0, len(missing_events), batch_size):
+ batch = set(missing_events[i:i + batch_size])
+
+ deferreds = [
+ self.get_pdu(
+ destinations=random_server_list(),
+ event_id=e_id,
+ )
+ for e_id in batch
+ ]
+
+ res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ for success, result in res:
+ if success:
+ signed_events.append(result)
+ batch.discard(result.event_id)
+
+ # We removed all events we successfully fetched from `batch`
+ failed_to_fetch.update(batch)
+
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ @defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
@@ -411,14 +550,19 @@ class FederationClient(FederationBase):
(destination, self.event_from_pdu_json(pdu_dict))
)
break
- except CodeMessageException:
- raise
+ except CodeMessageException as e:
+ if not 500 <= e.code < 600:
+ raise
+ else:
+ logger.warn(
+ "Failed to make_%s via %s: %s",
+ membership, destination, e.message
+ )
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
- raise
raise RuntimeError("Failed to send to any server.")
@@ -490,8 +634,14 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth,
"origin": destination,
})
- except CodeMessageException:
- raise
+ except CodeMessageException as e:
+ if not 500 <= e.code < 600:
+ raise
+ else:
+ logger.exception(
+ "Failed to send_join via %s: %s",
+ destination, e.message
+ )
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
@@ -551,6 +701,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
+ def get_public_rooms(self, destinations):
+ results_by_server = {}
+
+ @defer.inlineCallbacks
+ def _get_result(s):
+ if s == self.server_name:
+ defer.returnValue()
+
+ try:
+ result = yield self.transport_layer.get_public_rooms(s)
+ results_by_server[s] = result
+ except:
+ logger.exception("Error getting room list from server %r", s)
+
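+ # Query up to three servers at a time; a failure just means that
+ # server is omitted from the results.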
+ yield concurrently_execute(_get_result, destinations, 3)
+
+ defer.returnValue(results_by_server)
+
+ @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""
Params:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f1d231b9d8..aba19639c7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -19,11 +19,13 @@ from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu
+from synapse.util.async import Linearizer
from synapse.util.logutils import log_function
+from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
import synapse.metrics
-from synapse.api.errors import FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
@@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
class FederationServer(FederationBase):
+ def __init__(self, hs):
+ super(FederationServer, self).__init__(hs)
+
+ self.auth = hs.get_auth()
+
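+ # Linearizers to serialise fetching of missing events per room, and
+ # to serialise expensive requests per (origin, room) pair.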
+ self._room_pdu_linearizer = Linearizer()
+ self._server_linearizer = Linearizer()
+
+ # We cache responses to state queries, as they take a while and often
+ # come in waves.
+ self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
+
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
@@ -83,11 +97,14 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
- pdus = yield self.handler.on_backfill_request(
- origin, room_id, versions, limit
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ pdus = yield self.handler.on_backfill_request(
+ origin, room_id, versions, limit
+ )
+
+ res = self._transaction_from_pdus(pdus).get_dict()
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+ defer.returnValue((200, res))
@defer.inlineCallbacks
@log_function
@@ -178,15 +195,59 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
- if event_id:
- pdus = yield self.handler.get_state_for_pdu(
- origin, room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
+ if not event_id:
+ raise NotImplementedError("Specify an event")
+
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
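+ # If a response for this (room_id, event_id) is already cached (or is
+ # currently being computed), reuse it rather than recomputing the state.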
+ result = self._state_resp_cache.get((room_id, event_id))
+ if not result:
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.set(
+ (room_id, event_id),
+ self._on_context_state_request_compute(room_id, event_id)
+ )
+ else:
+ resp = yield result
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_state_ids_request(self, origin, room_id, event_id):
+ if not event_id:
+ raise NotImplementedError("Specify an event")
+
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
- for event in auth_chain:
+ pdus = yield self.handler.get_state_for_pdu(
+ room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+
+ defer.returnValue((200, {
+ "pdu_ids": [pdu.event_id for pdu in pdus],
+ "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+ }))
+
+ @defer.inlineCallbacks
+ def _on_context_state_request_compute(self, room_id, event_id):
+ pdus = yield self.handler.get_state_for_pdu(
+ room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+
+ for event in auth_chain:
+ # We sign these again because there was a bug where we
+ # incorrectly signed things the first time round
+ if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
@@ -194,13 +255,11 @@ class FederationServer(FederationBase):
self.hs.config.signing_key[0]
)
)
- else:
- raise NotImplementedError("Specify an event")
- defer.returnValue((200, {
+ defer.returnValue({
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }))
+ })
@defer.inlineCallbacks
@log_function
@@ -274,14 +333,16 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- defer.returnValue((200, {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }))
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ time_now = self._clock.time_msec()
+ auth_pdus = yield self.handler.on_event_auth(event_id)
+ res = {
+ "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+ }
+ defer.returnValue((200, res))
@defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, event_id):
+ def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
@@ -300,58 +361,41 @@ class FederationServer(FederationBase):
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
- auth_chain = [
- self.event_from_pdu_json(e)
- for e in content["auth_chain"]
- ]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ auth_chain = [
+ self.event_from_pdu_json(e)
+ for e in content["auth_chain"]
+ ]
+
+ signed_auth = yield self._check_sigs_and_hash_and_fetch(
+ origin, auth_chain, outlier=True
+ )
- ret = yield self.handler.on_query_auth(
- origin,
- event_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
+ ret = yield self.handler.on_query_auth(
+ origin,
+ event_id,
+ signed_auth,
+ content.get("rejects", []),
+ content.get("missing", []),
+ )
- time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [
- e.get_pdu_json(time_now)
- for e in ret["auth_chain"]
- ],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
+ time_now = self._clock.time_msec()
+ send_content = {
+ "auth_chain": [
+ e.get_pdu_json(time_now)
+ for e in ret["auth_chain"]
+ ],
+ "rejects": ret.get("rejects", []),
+ "missing": ret.get("missing", []),
+ }
defer.returnValue(
(200, send_content)
)
- @defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
- query = []
- for user_id, device_ids in content.get("device_keys", {}).items():
- if not device_ids:
- query.append((user_id, None))
- else:
- for device_id in device_ids:
- query.append((user_id, device_id))
-
- results = yield self.store.get_e2e_device_keys(query)
-
- json_result = {}
- for user_id, device_keys in results.items():
- for device_id, json_bytes in device_keys.items():
- json_result.setdefault(user_id, {})[device_id] = json.loads(
- json_bytes
- )
-
- defer.returnValue({"device_keys": json_result})
+ return self.on_query_request("client_keys", content)
@defer.inlineCallbacks
@log_function
@@ -377,11 +421,24 @@ class FederationServer(FederationBase):
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
- missing_events = yield self.handler.on_get_missing_events(
- origin, room_id, earliest_events, latest_events, limit, min_depth
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ logger.info(
+ "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+ " limit: %d, min_depth: %d",
+ earliest_events, latest_events, limit, min_depth
+ )
+ missing_events = yield self.handler.on_get_missing_events(
+ origin, room_id, earliest_events, latest_events, limit, min_depth
+ )
- time_now = self._clock.time_msec()
+ if len(missing_events) < 5:
+ logger.info(
+ "Returning %d events: %r", len(missing_events), missing_events
+ )
+ else:
+ logger.info("Returning %d events", len(missing_events))
+
+ time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
@@ -481,42 +538,59 @@ class FederationServer(FederationBase):
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
- latest = yield self.store.get_latest_event_ids_in_room(
- pdu.room_id
- )
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest)
- latest |= seen
-
- missing_events = yield self.get_missing_events(
- origin,
- pdu.room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- )
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for e in missing_events:
- yield self._handle_new_pdu(
- origin,
- e,
- get_missing=False
- )
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
+ # If we're missing prev events, ensure we only fetch missing events
+ # for one PDU at a time in this room.
+ with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ # We recalculate seen, since it may have changed.
+ have_seen = yield self.store.have_events(prevs)
+ seen = set(have_seen.keys())
+
+ if prevs - seen:
+ latest = yield self.store.get_latest_event_ids_in_room(
+ pdu.room_id
+ )
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest)
+ latest |= seen
+
+ logger.info(
+ "Missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
+
+ missing_events = yield self.get_missing_events(
+ origin,
+ pdu.room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ )
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ missing_events.sort(key=lambda x: x.depth)
+
+ for e in missing_events:
+ yield self._handle_new_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
+ logger.info(
+ "Still missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
fetch_state = True
if fetch_state:
@@ -531,7 +605,7 @@ class FederationServer(FederationBase):
origin, pdu.room_id, pdu.event_id,
)
except:
- logger.warn("Failed to get state for event: %s", pdu.event_id)
+ logger.exception("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu(
origin,
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 3e062a5eab..ea66a5dcbc 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -72,5 +72,7 @@ class ReplicationLayer(FederationClient, FederationServer):
self.hs = hs
+ super(ReplicationLayer, self).__init__(hs)
+
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5787f854d4..cb2ef0210c 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
+from synapse.util.metrics import measure_func
import synapse.metrics
import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer
- self._clock = hs.get_clock()
+ self.clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
# HACK to get unique tx id
- self._next_txn_id = int(self._clock.time_msec())
+ self._next_txn_id = int(self.clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -119,266 +119,215 @@ class TransactionQueue(object):
if not destinations:
return
- deferreds = []
-
for destination in destinations:
- deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
- (pdu, deferred, order)
+ (pdu, order)
)
- def chain(failure):
- if not deferred.called:
- deferred.errback(failure)
-
- def log_failure(f):
- logger.warn("Failed to send pdu to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
-
- deferreds.append(deferred)
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
- # NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if not self.can_send_to(destination):
return
- deferred = defer.Deferred()
- self.pending_edus_by_dest.setdefault(destination, []).append(
- (edu, deferred)
- )
+ self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- def chain(failure):
- if not deferred.called:
- deferred.errback(failure)
-
- def log_failure(f):
- logger.warn("Failed to send edu to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
-
- return deferred
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
- @defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
- deferred = defer.Deferred()
-
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
- ).append(
- (failure, deferred)
- )
-
- def chain(f):
- if not deferred.called:
- deferred.errback(f)
-
- def log_failure(f):
- logger.warn("Failed to send failure to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
+ ).append(failure)
- yield deferred
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
@defer.inlineCallbacks
- @log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
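+ # Keep looping, sending one transaction per iteration, until there is
+ # nothing left queued for this destination.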
+ while True:
+ # list of (pending_pdu, order)
+ if destination in self.pending_transactions:
+ # XXX: pending_transactions can get stuck on by a never-ending
+ # request at which point pending_pdus_by_dest just keeps growing.
+ # we need application-layer timeouts of some flavour of these
+ # requests
+ logger.debug(
+ "TX [%s] Transaction already in progress",
+ destination
+ )
+ return
- # list of (pending_pdu, deferred, order)
- if destination in self.pending_transactions:
- # XXX: pending_transactions can get stuck on by a never-ending
- # request at which point pending_pdus_by_dest just keeps growing.
- # we need application-layer timeouts of some flavour of these
- # requests
- logger.debug(
- "TX [%s] Transaction already in progress",
- destination
- )
- return
-
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- if pending_pdus:
- logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- destination, len(pending_pdus))
+ if pending_pdus:
+ logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ destination, len(pending_pdus))
- if not pending_pdus and not pending_edus and not pending_failures:
- logger.debug("TX [%s] Nothing to send", destination)
- return
+ if not pending_pdus and not pending_edus and not pending_failures:
+ logger.debug("TX [%s] Nothing to send", destination)
+ return
- try:
- self.pending_transactions[destination] = 1
+ yield self._send_new_transaction(
+ destination, pending_pdus, pending_edus, pending_failures
+ )
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ @measure_func("_send_new_transaction")
+ @defer.inlineCallbacks
+ def _send_new_transaction(self, destination, pending_pdus, pending_edus,
+ pending_failures):
# Sort based on the order field
- pending_pdus.sort(key=lambda t: t[2])
-
+ pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
- edus = [x[0] for x in pending_edus]
- failures = [x[0].get_dict() for x in pending_failures]
- deferreds = [
- x[1]
- for x in pending_pdus + pending_edus + pending_failures
- ]
-
- txn_id = str(self._next_txn_id)
-
- limiter = yield get_retry_limiter(
- destination,
- self._clock,
- self.store,
- )
+ edus = pending_edus
+ failures = [x.get_dict() for x in pending_failures]
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination, txn_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
- )
+ try:
+ self.pending_transactions[destination] = 1
- logger.debug("TX [%s] Persisting transaction...", destination)
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- transaction = Transaction.create_new(
- origin_server_ts=int(self._clock.time_msec()),
- transaction_id=txn_id,
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
+ txn_id = str(self._next_txn_id)
- self._next_txn_id += 1
+ limiter = yield get_retry_limiter(
+ destination,
+ self.clock,
+ self.store,
+ )
- yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination, txn_id,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures)
+ )
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures),
- )
+ logger.debug("TX [%s] Persisting transaction...", destination)
- with limiter:
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self._clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
- code = 200
-
- if response:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "Transaction returned error for %s: %s",
- e_id, r,
- )
- except HttpResponseException as e:
- code = e.code
- response = e.response
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
+ "TX [%s] {%s} Sending transaction [%s],"
+ " (PDUs: %d, EDUs: %d, failures: %d)",
+ destination, txn_id,
+ transaction.transaction_id,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures),
)
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
+ with limiter:
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+
+ if response:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "Transaction returned error for %s: %s",
+ e_id, r,
+ )
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
- logger.debug("TX [%s] Marked as delivered", destination)
-
- logger.debug("TX [%s] Yielding to callbacks...", destination)
-
- for deferred in deferreds:
- if code == 200:
- deferred.callback(None)
- else:
- deferred.errback(RuntimeError("Got status %d" % code))
-
- # Ensures we don't continue until all callbacks on that
- # deferred have fired
- try:
- yield deferred
- except:
- pass
-
- logger.debug("TX [%s] Yielded to callbacks", destination)
- except NotRetryingDestination:
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- except RuntimeError as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
- for deferred in deferreds:
- if not deferred.called:
- deferred.errback(e)
+ logger.debug("TX [%s] Marked as delivered", destination)
+
+ if code != 200:
+ for p in pdus:
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, destination
+ )
+ except NotRetryingDestination:
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
+ except RuntimeError as e:
+ # We capture this here as there is nothing that actually listens
+ # for this function's deferred to complete.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
+
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
+ except Exception as e:
+ # We capture this here as there is nothing that actually listens
+ # for this function's deferred to complete.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
- finally:
- # We want to be *very* sure we delete this after we stop processing
- self.pending_transactions.pop(destination, None)
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
- # Check to see if there is anything else to send.
- self._attempt_new_transaction(destination)
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index cd2841c4db..3d088e43cb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -55,6 +55,28 @@ class TransportLayerClient(object):
)
@log_function
+ def get_room_state_ids(self, destination, room_id, event_id):
+ """ Requests all state for a given room from the given server at the
+ given event. Returns the event IDs of the state and auth chain, not
+ the full events.
+
+ Args:
+ destination (str): The host name of the remote home server we want
+ to get the state from.
+ room_id (str): The room we want the state of.
+ event_id (str): The event we want the state at.
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug("get_room_state_ids dest=%s, room=%s",
+ destination, room_id)
+
+ path = PREFIX + "/state_ids/%s/" % room_id
+ return self.client.get_json(
+ destination, path=path, args={"event_id": event_id},
+ )
+
+ @log_function
def get_event(self, destination, event_id, timeout=None):
""" Requests the pdu with give id and origin from the given server.
@@ -226,6 +248,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def get_public_rooms(self, remote_server):
+ path = PREFIX + "/publicRooms"
+
+ response = yield self.client.get_json(
+ destination=remote_server,
+ path=path,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5b6c7d11dd..37c0d4fbc4 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,13 +18,14 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request, parse_string
+from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.versionstring import get_version_string
import functools
import logging
-import simplejson as json
import re
+import synapse
logger = logging.getLogger(__name__)
@@ -37,7 +38,7 @@ class TransportLayerServer(JsonResource):
self.hs = hs
self.clock = hs.get_clock()
- super(TransportLayerServer, self).__init__(hs)
+ super(TransportLayerServer, self).__init__(hs, canonical_json=False)
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
@@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
)
+class AuthenticationError(SynapseError):
+ """There was a problem authenticating the request"""
+ pass
+
+
+class NoAuthenticationError(AuthenticationError):
+ """The request had no authentication information"""
+ pass
+
+
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
@@ -67,7 +78,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
- def authenticate_request(self, request):
+ def authenticate_request(self, request, content):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -75,17 +86,10 @@ class Authenticator(object):
"signatures": {},
}
- content = None
- origin = None
+ if content is not None:
+ json_request["content"] = content
- if request.method in ["PUT", "POST"]:
- # TODO: Handle other method types? other content types?
- try:
- content_bytes = request.content.read()
- content = json.loads(content_bytes)
- json_request["content"] = content
- except:
- raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+ origin = None
def parse_auth_header(header_str):
try:
@@ -103,14 +107,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
- raise SynapseError(
+ raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
- raise SynapseError(
+ raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -121,7 +125,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
- raise SynapseError(
+ raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -130,38 +134,59 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin
- defer.returnValue((origin, content))
+ defer.returnValue(origin)
class BaseFederationServlet(object):
- def __init__(self, handler, authenticator, ratelimiter, server_name):
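+ # Servlets that handle unauthenticated requests (e.g. On3pidBindServlet,
+ # OpenIdUserInfo) override this with False.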
+ REQUIRE_AUTH = True
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name,
+ room_list_handler):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
+ self.room_list_handler = room_list_handler
- def _wrap(self, code):
+ def _wrap(self, func):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
- @functools.wraps(code)
- def new_code(request, *args, **kwargs):
+ @functools.wraps(func)
+ def new_func(request, *args, **kwargs):
+ content = None
+ if request.method in ["PUT", "POST"]:
+ # TODO: Handle other method types? other content types?
+ content = parse_json_object_from_request(request)
+
try:
- (origin, content) = yield authenticator.authenticate_request(request)
+ origin = yield authenticator.authenticate_request(request, content)
+ except NoAuthenticationError:
+ origin = None
+ if self.REQUIRE_AUTH:
+ logger.exception("authenticate_request failed")
+ raise
+ except:
+ logger.exception("authenticate_request failed")
+ raise
+
+ if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
- response = yield code(
+ response = yield func(
origin, content, request.args, *args, **kwargs
)
- except:
- logger.exception("authenticate_request failed")
- raise
+ else:
+ response = yield func(
+ origin, content, request.args, *args, **kwargs
+ )
+
defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish
- new_code.__self__ = code.__self__
+ new_func.__self__ = func.__self__
- return new_code
+ return new_func
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -269,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
)
+class FederationStateIdsServlet(BaseFederationServlet):
+ PATH = "/state_ids/(?P<room_id>[^/]*)/"
+
+ def on_GET(self, origin, content, query, room_id):
+ return self.handler.on_state_ids_request(
+ origin,
+ room_id,
+ query.get("event_id", [None])[0],
+ )
+
+
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/"
@@ -365,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
- @defer.inlineCallbacks
def on_POST(self, origin, content, query):
- response = yield self.handler.on_query_client_keys(origin, content)
- defer.returnValue((200, response))
+ return self.handler.on_query_client_keys(origin, content)
class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -386,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_POST(self, origin, content, query, context, event_id):
new_content = yield self.handler.on_query_auth_request(
- origin, content, event_id
+ origin, content, context, event_id
)
defer.returnValue((200, new_content))
@@ -418,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
+ REQUIRE_AUTH = False
+
@defer.inlineCallbacks
- def on_POST(self, request):
- content = parse_json_object_from_request(request)
+ def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -442,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception
defer.returnValue((200, {}))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
-
class OpenIdUserInfo(BaseFederationServlet):
"""
@@ -467,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo"
+ REQUIRE_AUTH = False
+
@defer.inlineCallbacks
- def on_GET(self, request):
- token = parse_string(request, "access_token")
+ def on_GET(self, origin, content, query):
+ token = query.get("access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -486,10 +518,58 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id}))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
+
+class PublicRoomList(BaseFederationServlet):
+ """
+ Fetch the public room list for this server.
+
+ This API returns information in the same format as /publicRooms on the
+ client API, but will only ever include local public rooms and hence is
+ intended for consumption by other home servers.
+
+ GET /publicRooms HTTP/1.1
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "chunk": [
+ {
+ "aliases": [
+ "#test:localhost"
+ ],
+ "guest_can_join": false,
+ "name": "test room",
+ "num_joined_members": 3,
+ "room_id": "!whkydVegtvatLfXmPN:localhost",
+ "world_readable": false
+ }
+ ],
+ "end": "END",
+ "start": "START"
+ }
+ """
+
+ PATH = "/publicRooms"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query):
+ data = yield self.room_list_handler.get_local_public_room_list()
+ defer.returnValue((200, data))
+
+
+class FederationVersionServlet(BaseFederationServlet):
+ PATH = "/version"
+
+ REQUIRE_AUTH = False
+
+ def on_GET(self, origin, content, query):
+ return defer.succeed((200, {
+ "server": {
+ "name": "Synapse",
+ "version": get_version_string(synapse)
+ },
+ }))
SERVLET_CLASSES = (
@@ -497,6 +577,7 @@ SERVLET_CLASSES = (
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
+ FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
@@ -513,6 +594,8 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
OpenIdUserInfo,
+ PublicRoomList,
+ FederationVersionServlet,
)
@@ -523,4 +606,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
+ room_list_handler=hs.get_room_list_handler(),
).register(resource)
|