diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e8bfbe7cb5..3fa7b2315c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -19,11 +19,13 @@ from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu
+from synapse.util.async import Linearizer
from synapse.util.logutils import log_function
+from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
import synapse.metrics
-from synapse.api.errors import FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
@@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
class FederationServer(FederationBase):
+ def __init__(self, hs):
+ super(FederationServer, self).__init__(hs)
+
+ self.auth = hs.get_auth()
+
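+        # Linearizers serialise work on a per-key basis: requests from the
+        # same (origin, room_id) pair, and PDU handling within the same room,
+        # are processed one at a time rather than concurrently.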
+ self._room_pdu_linearizer = Linearizer()
+ self._server_linearizer = Linearizer()
+
+ # We cache responses to state queries, as they take a while and often
+ # come in waves.
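+        # A second request for the same (room_id, event_id) while one is
+        # already in flight will share the in-flight result rather than
+        # recomputing it; entries expire after the 30s timeout.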
+ self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
+
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
@@ -83,11 +97,14 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
- pdus = yield self.handler.on_backfill_request(
- origin, room_id, versions, limit
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ pdus = yield self.handler.on_backfill_request(
+ origin, room_id, versions, limit
+ )
+
+ res = self._transaction_from_pdus(pdus).get_dict()
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+ defer.returnValue((200, res))
@defer.inlineCallbacks
@log_function
@@ -137,8 +154,8 @@ class FederationServer(FederationBase):
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"):
- for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(
+ for edu in (Edu(**x) for x in transaction.edus):
+ yield self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
@@ -161,26 +178,74 @@ class FederationServer(FederationBase):
)
defer.returnValue((200, response))
+ @defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
if edu_type in self.edu_handlers:
- self.edu_handlers[edu_type](origin, content)
+ try:
+ yield self.edu_handlers[edu_type](origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+            except Exception:
+ logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
- if event_id:
- pdus = yield self.handler.get_state_for_pdu(
- origin, room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
+ if not event_id:
+ raise NotImplementedError("Specify an event")
+
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
+ result = self._state_resp_cache.get((room_id, event_id))
+ if not result:
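+            # Nothing cached or in flight for this (room_id, event_id):
+            # compute the response under the linearizer, storing the deferred
+            # in the cache so that concurrent requests share the same work.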
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.set(
+ (room_id, event_id),
+ self._on_context_state_request_compute(room_id, event_id)
+ )
+ else:
+ resp = yield result
+
+ defer.returnValue((200, resp))
- for event in auth_chain:
+ @defer.inlineCallbacks
+ def on_state_ids_request(self, origin, room_id, event_id):
+ if not event_id:
+ raise NotImplementedError("Specify an event")
+
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
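+        # Return just the event IDs for the state and auth chain, rather than
+        # the full events, so the requesting server can fetch only the events
+        # it is missing.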
+ state_ids = yield self.handler.get_state_ids_for_pdu(
+ room_id, event_id,
+ )
+ auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
+
+ defer.returnValue((200, {
+ "pdu_ids": state_ids,
+ "auth_chain_ids": auth_chain_ids,
+ }))
+
+ @defer.inlineCallbacks
+ def _on_context_state_request_compute(self, room_id, event_id):
+ pdus = yield self.handler.get_state_for_pdu(
+ room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+
+ for event in auth_chain:
+ # We sign these again because there was a bug where we
+ # incorrectly signed things the first time round
+ if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
@@ -188,13 +253,11 @@ class FederationServer(FederationBase):
self.hs.config.signing_key[0]
)
)
- else:
- raise NotImplementedError("Specify an event")
- defer.returnValue((200, {
+ defer.returnValue({
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }))
+ })
@defer.inlineCallbacks
@log_function
@@ -268,14 +331,16 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- defer.returnValue((200, {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }))
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ time_now = self._clock.time_msec()
+ auth_pdus = yield self.handler.on_event_auth(event_id)
+ res = {
+ "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+ }
+ defer.returnValue((200, res))
@defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, event_id):
+ def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
@@ -294,58 +359,41 @@ class FederationServer(FederationBase):
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
- auth_chain = [
- self.event_from_pdu_json(e)
- for e in content["auth_chain"]
- ]
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ auth_chain = [
+ self.event_from_pdu_json(e)
+ for e in content["auth_chain"]
+ ]
+
+ signed_auth = yield self._check_sigs_and_hash_and_fetch(
+ origin, auth_chain, outlier=True
+ )
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True
- )
+ ret = yield self.handler.on_query_auth(
+ origin,
+ event_id,
+ signed_auth,
+ content.get("rejects", []),
+ content.get("missing", []),
+ )
- ret = yield self.handler.on_query_auth(
- origin,
- event_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
-
- time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [
- e.get_pdu_json(time_now)
- for e in ret["auth_chain"]
- ],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
+ time_now = self._clock.time_msec()
+ send_content = {
+ "auth_chain": [
+ e.get_pdu_json(time_now)
+ for e in ret["auth_chain"]
+ ],
+ "rejects": ret.get("rejects", []),
+ "missing": ret.get("missing", []),
+ }
defer.returnValue(
(200, send_content)
)
- @defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
- query = []
- for user_id, device_ids in content.get("device_keys", {}).items():
- if not device_ids:
- query.append((user_id, None))
- else:
- for device_id in device_ids:
- query.append((user_id, device_id))
-
- results = yield self.store.get_e2e_device_keys(query)
-
- json_result = {}
- for user_id, device_keys in results.items():
- for device_id, json_bytes in device_keys.items():
- json_result.setdefault(user_id, {})[device_id] = json.loads(
- json_bytes
- )
-
- defer.returnValue({"device_keys": json_result})
+ return self.on_query_request("client_keys", content)
@defer.inlineCallbacks
@log_function
@@ -371,17 +419,35 @@ class FederationServer(FederationBase):
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
- missing_events = yield self.handler.on_get_missing_events(
- origin, room_id, earliest_events, latest_events, limit, min_depth
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ logger.info(
+ "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+ " limit: %d, min_depth: %d",
+ earliest_events, latest_events, limit, min_depth
+ )
+ missing_events = yield self.handler.on_get_missing_events(
+ origin, room_id, earliest_events, latest_events, limit, min_depth
+ )
- time_now = self._clock.time_msec()
+ if len(missing_events) < 5:
+ logger.info(
+ "Returning %d events: %r", len(missing_events), missing_events
+ )
+ else:
+ logger.info("Returning %d events", len(missing_events))
+
+ time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
@log_function
+ def on_openid_userinfo(self, token):
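+        # Map an OpenID access token back to the user it was issued for; the
+        # current timestamp is passed so the store can reject expired tokens.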
+ ts_now_ms = self._clock.time_msec()
+ return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
+
+ @log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
@@ -470,42 +536,59 @@ class FederationServer(FederationBase):
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
- latest = yield self.store.get_latest_event_ids_in_room(
- pdu.room_id
- )
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest)
- latest |= seen
-
- missing_events = yield self.get_missing_events(
- origin,
- pdu.room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- )
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for e in missing_events:
- yield self._handle_new_pdu(
- origin,
- e,
- get_missing=False
- )
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
+                # If we're missing some of this PDU's prev_events, make sure
+                # we only fetch them one at a time for any given room.
+ with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ # We recalculate seen, since it may have changed.
+ have_seen = yield self.store.have_events(prevs)
+ seen = set(have_seen.keys())
+
+ if prevs - seen:
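+                        # We still don't have all of the prev_events, so ask
+                        # the origin server for the events we're missing.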
+ latest = yield self.store.get_latest_event_ids_in_room(
+ pdu.room_id
+ )
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest)
+ latest |= seen
+
+ logger.info(
+ "Missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
+
+ missing_events = yield self.get_missing_events(
+ origin,
+ pdu.room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ )
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ missing_events.sort(key=lambda x: x.depth)
+
+ for e in missing_events:
+ yield self._handle_new_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
+ logger.info(
+ "Still missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
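+                # We weren't able to backfill everything that was missing, so
+                # fall back to fetching the state at this event from the
+                # origin.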
fetch_state = True
if fetch_state:
@@ -520,12 +603,11 @@ class FederationServer(FederationBase):
origin, pdu.room_id, pdu.event_id,
)
except:
- logger.warn("Failed to get state for event: %s", pdu.event_id)
+ logger.exception("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu(
origin,
pdu,
- backfilled=False,
state=state,
auth_chain=auth_chain,
)
|