diff options
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/pdu_codec.py | 49 | ||||
-rw-r--r-- | synapse/federation/replication.py | 119 | ||||
-rw-r--r-- | synapse/federation/transport.py | 139 | ||||
-rw-r--r-- | synapse/federation/units.py | 36 |
4 files changed, 289 insertions, 54 deletions
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py index e8180d94fd..6d31286290 100644 --- a/synapse/federation/pdu_codec.py +++ b/synapse/federation/pdu_codec.py @@ -14,41 +14,43 @@ # limitations under the License. from .units import Pdu +from synapse.crypto.event_signing import ( + add_event_pdu_content_hash, sign_event_pdu +) +from synapse.types import EventID import copy -def decode_event_id(event_id, server_name): - parts = event_id.split("@") - if len(parts) < 2: - return (event_id, server_name) - else: - return (parts[0], "".join(parts[1:])) - - -def encode_event_id(pdu_id, origin): - return "%s@%s" % (pdu_id, origin) - - class PduCodec(object): def __init__(self, hs): + self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname self.event_factory = hs.get_event_factory() self.clock = hs.get_clock() + self.hs = hs + + def encode_event_id(self, local, domain): + return EventID.create(local, domain, self.hs).to_string() + + def decode_event_id(self, event_id): + e_id = self.hs.parse_eventid(event_id) + return e_id.localpart, e_id.domain def event_from_pdu(self, pdu): kwargs = {} - kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) + kwargs["event_id"] = self.encode_event_id(pdu.pdu_id, pdu.origin) kwargs["room_id"] = pdu.context kwargs["etype"] = pdu.pdu_type kwargs["prev_events"] = [ - encode_event_id(p[0], p[1]) for p in pdu.prev_pdus + (self.encode_event_id(i, o), s) + for i, o, s in pdu.prev_pdus ] if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): - kwargs["prev_state"] = encode_event_id( + kwargs["prev_state"] = self.encode_event_id( pdu.prev_state_id, pdu.prev_state_origin ) @@ -70,21 +72,24 @@ class PduCodec(object): def pdu_from_event(self, event): d = event.get_full_dict() - d["pdu_id"], d["origin"] = decode_event_id( - event.event_id, self.server_name + d["pdu_id"], d["origin"] = self.decode_event_id( + event.event_id ) d["context"] = event.room_id d["pdu_type"] = event.type if hasattr(event, "prev_events"): + def f(e, s): + i, o = self.decode_event_id(e) + return i, o, s d["prev_pdus"] = [ - decode_event_id(e, self.server_name) - for e in event.prev_events + f(e, s) + for e, s in event.prev_events ] if hasattr(event, "prev_state"): d["prev_state_id"], d["prev_state_origin"] = ( - decode_event_id(event.prev_state, self.server_name) + self.decode_event_id(event.prev_state) ) if hasattr(event, "state_key"): @@ -99,4 +104,6 @@ class PduCodec(object): if "origin_server_ts" not in kwargs: kwargs["origin_server_ts"] = int(self.clock.time_msec()) - return Pdu(**kwargs) + pdu = Pdu(**kwargs) + pdu = add_event_pdu_content_hash(pdu) + return sign_event_pdu(pdu, self.server_name, self.signing_key) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 092411eaf9..000a3081c2 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -244,13 +244,14 @@ class ReplicationLayer(object): pdu = None if pdu_list: pdu = pdu_list[0] - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdu) @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context): + def get_state_for_context(self, destination, context, pdu_id=None, + pdu_origin=None): """Requests all of the `current` state PDUs for a given context from a remote home server. @@ -263,13 +264,14 @@ class ReplicationLayer(object): """ transaction_data = yield self.transport_layer.get_context_state( - destination, context) + destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin, + ) transaction = Transaction(**transaction_data) pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdus) @@ -295,6 +297,10 @@ class ReplicationLayer(object): transaction = Transaction(**transaction_data) for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] if "age" in p: p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) del p["age"] @@ -315,7 +321,7 @@ class ReplicationLayer(object): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(pdu)) + dl.append(self._handle_new_pdu(transaction.origin, pdu)) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -347,14 +353,19 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def on_context_state_request(self, context): - results = yield self.store.get_current_state_for_context( - context - ) + def on_context_state_request(self, context, pdu_id, pdu_origin): + if pdu_id and pdu_origin: + pdus = yield self.handler.get_state_for_pdu( + pdu_id, pdu_origin + ) + else: + results = yield self.store.get_current_state_for_context( + context + ) + pdus = [Pdu.from_pdu_tuple(p) for p in results] - logger.debug("Context returning %d results", len(results)) + logger.debug("Context returning %d results", len(pdus)) - pdus = [Pdu.from_pdu_tuple(p) for p in results] defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @@ -393,9 +404,55 @@ class ReplicationLayer(object): response = yield self.query_handlers[query_type](args) defer.returnValue((200, response)) else: - defer.returnValue((404, "No handler for Query type '%s'" - % (query_type) - )) + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type, )) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, context, user_id): + pdu = yield self.handler.on_make_join_request(context, user_id) + defer.returnValue(pdu.get_dict()) + + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = Pdu(**content) + ret_pdu = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, ret_pdu.get_dict())) + + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + pdu = Pdu(**content) + state = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, self._transaction_from_pdus(state).get_dict())) + + @defer.inlineCallbacks + def make_join(self, destination, context, user_id): + pdu_dict = yield self.transport_layer.make_join( + destination=destination, + context=context, + user_id=user_id, + ) + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) + + @defer.inlineCallbacks + def send_join(self, destination, pdu): + _, content = yield self.transport_layer.send_join( + destination, + pdu.context, + pdu.pdu_id, + pdu.origin, + pdu.get_dict(), + ) + + logger.debug("Got content: %s", content) + pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])] + for pdu in pdus: + yield self._handle_new_pdu(destination, pdu) + + defer.returnValue(pdus) @defer.inlineCallbacks @log_function @@ -414,20 +471,22 @@ class ReplicationLayer(object): transmission. """ pdus = [p.get_dict() for p in pdu_list] + time_now = self._clock.time_msec() for p in pdus: - if "age_ts" in pdus: - p["age"] = int(self.clock.time_msec()) - p["age_ts"] - + if "age_ts" in p: + age = time_now - p["age_ts"] + p.setdefault("unsigned", {})["age"] = int(age) + del p["age_ts"] return Transaction( origin=self.server_name, pdus=pdus, - origin_server_ts=int(self._clock.time_msec()), + origin_server_ts=int(time_now), destination=None, ) @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, pdu, backfilled=False): + def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) @@ -436,6 +495,8 @@ class ReplicationLayer(object): defer.returnValue({}) return + state = None + # Get missing pdus if necessary. is_new = yield self.pdu_actions.is_new(pdu) if is_new and not pdu.outlier: @@ -443,7 +504,7 @@ class ReplicationLayer(object): min_depth = yield self.store.get_min_depth_for_context(pdu.context) if min_depth and pdu.depth > min_depth: - for pdu_id, origin in pdu.prev_pdus: + for pdu_id, origin, hashes in pdu.prev_pdus: exists = yield self._get_persisted_pdu(pdu_id, origin) if not exists: @@ -459,12 +520,22 @@ class ReplicationLayer(object): except: # TODO(erikj): Do some more intelligent retries. logger.exception("Failed to get PDU") + else: + # We need to get the state at this event, since we have reached + # a backward extremity edge. + state = yield self.get_state_for_context( + origin, pdu.context, pdu.pdu_id, pdu.origin, + ) # Persist the Pdu, but don't mark it as processed yet. yield self.store.persist_event(pdu=pdu) if not backfilled: - ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) + ret = yield self.handler.on_receive_pdu( + pdu, + backfilled=backfilled, + state=state, + ) else: ret = None @@ -589,7 +660,7 @@ class _TransactionQueue(object): logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=self._clock.time_msec(), + origin_server_ts=int(self._clock.time_msec()), transaction_id=str(self._next_txn_id), origin=self.server_name, destination=destination, @@ -614,7 +685,9 @@ class _TransactionQueue(object): if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: - p["age"] = now - int(p["age_ts"]) + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] return data code, response = yield self.transport_layer.send_transaction( diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index e7517cac4d..7f01b4faaf 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -72,7 +72,8 @@ class TransportLayer(object): self.received_handler = None @log_function - def get_context_state(self, destination, context): + def get_context_state(self, destination, context, pdu_id=None, + pdu_origin=None): """ Requests all state for a given context (i.e. room) from the given server. @@ -89,7 +90,14 @@ class TransportLayer(object): subpath = "/state/%s/" % context - return self._do_request_for_transaction(destination, subpath) + args = {} + if pdu_id and pdu_origin: + args["pdu_id"] = pdu_id + args["pdu_origin"] = pdu_origin + + return self._do_request_for_transaction( + destination, subpath, args=args + ) @log_function def get_pdu(self, destination, pdu_origin, pdu_id): @@ -135,8 +143,10 @@ class TransportLayer(object): subpath = "/backfill/%s/" % context - args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} - args["limit"] = limit + args = { + "v": ["%s,%s" % (i, o) for i, o in pdu_tuples], + "limit": limit, + } return self._do_request_for_transaction( dest, @@ -198,6 +208,59 @@ class TransportLayer(object): defer.returnValue(response) @defer.inlineCallbacks + @log_function + def make_join(self, destination, context, user_id, retry_on_dns_fail=True): + path = PREFIX + "/make_join/%s/%s" % (context, user_id,) + + response = yield self.client.get_json( + destination=destination, + path=path, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_join(self, destination, context, pdu_id, origin, content): + path = PREFIX + "/send_join/%s/%s/%s" % ( + context, + origin, + pdu_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, context, pdu_id, origin, content): + path = PREFIX + "/invite/%s/%s/%s" % ( + context, + origin, + pdu_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks def _authenticate_request(self, request): json_request = { "method": request.method, @@ -326,7 +389,11 @@ class TransportLayer(object): re.compile("^" + PREFIX + "/state/([^/]*)/$"), self._with_authentication( lambda origin, content, query, context: - handler.on_context_state_request(context) + handler.on_context_state_request( + context, + query.get("pdu_id", [None])[0], + query.get("pdu_origin", [None])[0] + ) ) ) @@ -362,6 +429,39 @@ class TransportLayer(object): ) ) + self.server.register_path( + "GET", + re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, user_id: + self._on_make_join_request( + origin, content, query, context, user_id + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, pdu_origin, pdu_id: + self._on_send_join_request( + origin, content, query, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, pdu_origin, pdu_id: + self._on_invite_request( + origin, content, query, + ) + ) + ) + @defer.inlineCallbacks @log_function def _on_send_request(self, origin, content, query, transaction_id): @@ -451,7 +551,34 @@ class TransportLayer(object): versions = [v.split(",", 1) for v in v_list] return self.request_handler.on_backfill_request( - context, versions, limit) + context, versions, limit + ) + + @defer.inlineCallbacks + @log_function + def _on_make_join_request(self, origin, content, query, context, user_id): + content = yield self.request_handler.on_make_join_request( + context, user_id, + ) + defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_send_join_request(self, origin, content, query): + content = yield self.request_handler.on_send_join_request( + origin, content, + ) + + defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_invite_request(self, origin, content, query): + content = yield self.request_handler.on_invite_request( + origin, content, + ) + + defer.returnValue((200, content)) class TransportReceivedHandler(object): diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b2fb964180..adc3385644 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -18,6 +18,7 @@ server protocol. """ from synapse.util.jsonobject import JsonEncodedObject +from syutil.base64util import encode_base64 import logging import json @@ -63,9 +64,10 @@ class Pdu(JsonEncodedObject): "depth", "content", "outlier", + "hashes", + "signatures", "is_state", # Below this are keys valid only for State Pdus. "state_key", - "power_level", "prev_state_id", "prev_state_origin", "required_power_level", @@ -91,7 +93,7 @@ class Pdu(JsonEncodedObject): # just leaving it as a dict. (OR DO WE?!) def __init__(self, destinations=[], is_state=False, prev_pdus=[], - outlier=False, **kwargs): + outlier=False, hashes={}, signatures={}, **kwargs): if is_state: for required_key in ["state_key"]: if required_key not in kwargs: @@ -99,9 +101,11 @@ class Pdu(JsonEncodedObject): super(Pdu, self).__init__( destinations=destinations, - is_state=is_state, + is_state=bool(is_state), prev_pdus=prev_pdus, outlier=outlier, + hashes=hashes, + signatures=signatures, **kwargs ) @@ -120,6 +124,10 @@ class Pdu(JsonEncodedObject): d = copy.copy(pdu_tuple.pdu_entry._asdict()) d["origin_server_ts"] = d.pop("ts") + for k in d.keys(): + if d[k] is None: + del d[k] + d["content"] = json.loads(d["content_json"]) del d["content_json"] @@ -127,8 +135,28 @@ class Pdu(JsonEncodedObject): if "unrecognized_keys" in d and d["unrecognized_keys"]: args.update(json.loads(d["unrecognized_keys"])) + hashes = { + alg: encode_base64(hsh) + for alg, hsh in pdu_tuple.hashes.items() + } + + signatures = { + kid: encode_base64(sig) + for kid, sig in pdu_tuple.signatures.items() + } + + prev_pdus = [] + for prev_pdu in pdu_tuple.prev_pdu_list: + prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {}) + prev_hashes = { + alg: encode_base64(hsh) for alg, hsh in prev_hashes.items() + } + prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes)) + return Pdu( - prev_pdus=pdu_tuple.prev_pdu_list, + prev_pdus=prev_pdus, + hashes=hashes, + signatures=signatures, **args ) else: |