diff options
author | Mark Haines <mjark@negativecurvature.net> | 2014-11-11 16:40:50 +0000 |
---|---|---|
committer | Mark Haines <mjark@negativecurvature.net> | 2014-11-11 16:40:50 +0000 |
commit | a8ceeec0fd512e287cbf71efff42015787517a5d (patch) | |
tree | 45643674a31b637799e347f2251c72417e685616 /synapse/federation | |
parent | no evil horizontal textarea resizing (diff) | |
parent | Fix bugs which broke federation due to changes in function signatures. (diff) | |
download | synapse-a8ceeec0fd512e287cbf71efff42015787517a5d.tar.xz |
Merge pull request #12 from matrix-org/federation_authorization
Federation authorization
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/pdu_codec.py | 60 | ||||
-rw-r--r-- | synapse/federation/persistence.py | 73 | ||||
-rw-r--r-- | synapse/federation/replication.py | 319 | ||||
-rw-r--r-- | synapse/federation/transport.py | 320 | ||||
-rw-r--r-- | synapse/federation/units.py | 84 |
5 files changed, 414 insertions, 442 deletions
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py index e8180d94fd..52c84efb5b 100644 --- a/synapse/federation/pdu_codec.py +++ b/synapse/federation/pdu_codec.py @@ -18,50 +18,25 @@ from .units import Pdu import copy -def decode_event_id(event_id, server_name): - parts = event_id.split("@") - if len(parts) < 2: - return (event_id, server_name) - else: - return (parts[0], "".join(parts[1:])) - - -def encode_event_id(pdu_id, origin): - return "%s@%s" % (pdu_id, origin) - - class PduCodec(object): def __init__(self, hs): + self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname self.event_factory = hs.get_event_factory() self.clock = hs.get_clock() + self.hs = hs def event_from_pdu(self, pdu): kwargs = {} - kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) - kwargs["room_id"] = pdu.context - kwargs["etype"] = pdu.pdu_type - kwargs["prev_events"] = [ - encode_event_id(p[0], p[1]) for p in pdu.prev_pdus - ] - - if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): - kwargs["prev_state"] = encode_event_id( - pdu.prev_state_id, pdu.prev_state_origin - ) + kwargs["etype"] = pdu.type kwargs.update({ k: v for k, v in pdu.get_full_dict().items() if k not in [ - "pdu_id", - "context", - "pdu_type", - "prev_pdus", - "prev_state_id", - "prev_state_origin", + "type", ] }) @@ -70,33 +45,10 @@ class PduCodec(object): def pdu_from_event(self, event): d = event.get_full_dict() - d["pdu_id"], d["origin"] = decode_event_id( - event.event_id, self.server_name - ) - d["context"] = event.room_id - d["pdu_type"] = event.type - - if hasattr(event, "prev_events"): - d["prev_pdus"] = [ - decode_event_id(e, self.server_name) - for e in event.prev_events - ] - - if hasattr(event, "prev_state"): - d["prev_state_id"], d["prev_state_origin"] = ( - decode_event_id(event.prev_state, self.server_name) - ) - - if hasattr(event, "state_key"): - d["is_state"] = True - kwargs = copy.deepcopy(event.unrecognized_keys) kwargs.update({ k: v for k, v in d.items() - if k not in ["event_id", "room_id", "type", "prev_events"] }) - if "origin_server_ts" not in kwargs: - kwargs["origin_server_ts"] = int(self.clock.time_msec()) - - return Pdu(**kwargs) + pdu = Pdu(**kwargs) + return pdu diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 7043fcc504..73dc844d59 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module. from twisted.internet import defer -from .units import Pdu - from synapse.util.logutils import log_function import json @@ -32,76 +30,6 @@ import logging logger = logging.getLogger(__name__) -class PduActions(object): - """ Defines persistence actions that relate to handling PDUs. - """ - - def __init__(self, datastore): - self.store = datastore - - @log_function - def mark_as_processed(self, pdu): - """ Persist the fact that we have fully processed the given `Pdu` - - Returns: - Deferred - """ - return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin) - - @defer.inlineCallbacks - @log_function - def after_transaction(self, transaction_id, destination, origin): - """ Returns all `Pdu`s that we sent to the given remote home server - after a given transaction id. - - Returns: - Deferred: Results in a list of `Pdu`s - """ - results = yield self.store.get_pdus_after_transaction( - transaction_id, - destination - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def get_all_pdus_from_context(self, context): - results = yield self.store.get_all_pdus_from_context(context) - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def backfill(self, context, pdu_list, limit): - """ For a given list of PDU id and origins return the proceeding - `limit` `Pdu`s in the given `context`. - - Returns: - Deferred: Results in a list of `Pdu`s. - """ - results = yield self.store.get_backfill( - context, pdu_list, limit - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @log_function - def is_new(self, pdu): - """ When we receive a `Pdu` from a remote home server, we want to - figure out whether it is `new`, i.e. it is not some historic PDU that - we haven't seen simply because we haven't backfilled back that far. - - Returns: - Deferred: Results in a `bool` - """ - return self.store.is_pdu_new( - pdu_id=pdu.pdu_id, - origin=pdu.origin, - context=pdu.context, - depth=pdu.depth - ) - - class TransactionActions(object): """ Defines persistence actions that relate to handling Transactions. """ @@ -158,7 +86,6 @@ class TransactionActions(object): transaction.transaction_id, transaction.destination, transaction.origin_server_ts, - [(p["pdu_id"], p["origin"]) for p in transaction.pdus] ) @log_function diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 092411eaf9..5c625ddabf 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -21,7 +21,7 @@ from twisted.internet import defer from .units import Transaction, Pdu, Edu -from .persistence import PduActions, TransactionActions +from .persistence import TransactionActions from synapse.util.logutils import log_function @@ -57,7 +57,7 @@ class ReplicationLayer(object): self.transport_layer.register_request_handler(self) self.store = hs.get_datastore() - self.pdu_actions = PduActions(self.store) + # self.pdu_actions = PduActions(self.store) self.transaction_actions = TransactionActions(self.store) self._transaction_queue = _TransactionQueue( @@ -81,7 +81,7 @@ class ReplicationLayer(object): def register_edu_handler(self, edu_type, handler): if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type)) + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) self.edu_handlers[edu_type] = handler @@ -102,24 +102,17 @@ class ReplicationLayer(object): object to encode as JSON. """ if query_type in self.query_handlers: - raise KeyError("Already have a Query handler for %s" % (query_type)) + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) self.query_handlers[query_type] = handler - @defer.inlineCallbacks @log_function def send_pdu(self, pdu): """Informs the replication layer about a new PDU generated within the home server that should be transmitted to others. - This will fill out various attributes on the PDU object, e.g. the - `prev_pdus` key. - - *Note:* The home server should always call `send_pdu` even if it knows - that it does not need to be replicated to other home servers. This is - in case e.g. someone else joins via a remote home server and then - backfills. - TODO: Figure out when we should actually resolve the deferred. Args: @@ -132,18 +125,15 @@ class ReplicationLayer(object): order = self._order self._order += 1 - logger.debug("[%s] Persisting PDU", pdu.pdu_id) - - # Save *before* trying to send - yield self.store.persist_event(pdu=pdu) - - logger.debug("[%s] Persisted PDU", pdu.pdu_id) - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id) + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) # TODO, add errback, etc. self._transaction_queue.enqueue_pdu(pdu, order) - logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id) + logger.debug( + "[%s] transaction_layer.enqueue_pdu... done", + pdu.event_id + ) @log_function def send_edu(self, destination, edu_type, content): @@ -159,6 +149,11 @@ class ReplicationLayer(object): return defer.succeed(None) @log_function + def send_failure(self, failure, destination): + self._transaction_queue.enqueue_failure(failure, destination) + return defer.succeed(None) + + @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=True): """Sends a federation Query to a remote homeserver of the given type @@ -181,7 +176,7 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def backfill(self, dest, context, limit): + def backfill(self, dest, context, limit, extremities): """Requests some more historic PDUs for the given context from the given destination server. @@ -189,12 +184,12 @@ class ReplicationLayer(object): dest (str): The remote home server to ask. context (str): The context to backfill. limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context Returns: Deferred: Results in the received PDUs. """ - extremities = yield self.store.get_oldest_pdus_in_context(context) - logger.debug("backfill extrem=%s", extremities) # If there are no extremeties then we've (probably) reached the start. @@ -210,13 +205,13 @@ class ReplicationLayer(object): pdus = [Pdu(outlier=False, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu, backfilled=True) + yield self._handle_new_pdu(dest, pdu, backfilled=True) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): + def get_pdu(self, destination, event_id, outlier=False): """Requests the PDU with given origin and ID from the remote home server. @@ -225,7 +220,7 @@ class ReplicationLayer(object): Args: destination (str): Which home server to query pdu_origin (str): The home server that originally sent the pdu. - pdu_id (str) + event_id (str) outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. Defaults to `False` @@ -234,8 +229,9 @@ class ReplicationLayer(object): Deferred: Results in the requested PDU. """ - transaction_data = yield self.transport_layer.get_pdu( - destination, pdu_origin, pdu_id) + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) transaction = Transaction(**transaction_data) @@ -244,13 +240,13 @@ class ReplicationLayer(object): pdu = None if pdu_list: pdu = pdu_list[0] - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdu) @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context): + def get_state_for_context(self, destination, context, event_id=None): """Requests all of the `current` state PDUs for a given context from a remote home server. @@ -263,29 +259,25 @@ class ReplicationLayer(object): """ transaction_data = yield self.transport_layer.get_context_state( - destination, context) + destination, + context, + event_id=event_id, + ) transaction = Transaction(**transaction_data) pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def on_context_pdus_request(self, context): - pdus = yield self.pdu_actions.get_all_pdus_from_context( - context + def on_backfill_request(self, origin, context, versions, limit): + pdus = yield self.handler.on_backfill_request( + origin, context, versions, limit ) - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) - - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, context, versions, limit): - - pdus = yield self.pdu_actions.backfill(context, versions, limit) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @@ -295,6 +287,10 @@ class ReplicationLayer(object): transaction = Transaction(**transaction_data) for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] if "age" in p: p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) del p["age"] @@ -315,11 +311,15 @@ class ReplicationLayer(object): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(pdu)) + dl.append(self._handle_new_pdu(transaction.origin, pdu)) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu(transaction.origin, edu.edu_type, edu.content) + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) results = yield defer.DeferredList(dl) @@ -347,20 +347,22 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def on_context_state_request(self, context): - results = yield self.store.get_current_state_for_context( - context - ) - - logger.debug("Context returning %d results", len(results)) + def on_context_state_request(self, origin, context, event_id): + if event_id: + pdus = yield self.handler.get_state_for_pdu( + origin, + context, + event_id, + ) + else: + raise NotImplementedError("Specify an event") - pdus = [Pdu.from_pdu_tuple(p) for p in results] defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @log_function - def on_pdu_request(self, pdu_origin, pdu_id): - pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin) + def on_pdu_request(self, origin, event_id): + pdu = yield self._get_persisted_pdu(origin, event_id) if pdu: defer.returnValue( @@ -372,103 +374,191 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function def on_pull_request(self, origin, versions): - transaction_id = max([int(v) for v in versions]) + raise NotImplementedError("Pull transacions not implemented") + + @defer.inlineCallbacks + def on_query_request(self, query_type, args): + if query_type in self.query_handlers: + response = yield self.query_handlers[query_type](args) + defer.returnValue((200, response)) + else: + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type, )) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, context, user_id): + pdu = yield self.handler.on_make_join_request(context, user_id) + defer.returnValue({ + "event": pdu.get_dict(), + }) - response = yield self.pdu_actions.after_transaction( - transaction_id, - origin, - self.server_name + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = Pdu(**content) + ret_pdu = yield self.handler.on_invite_request(origin, pdu) + defer.returnValue( + ( + 200, + { + "event": ret_pdu.get_dict(), + } + ) ) - if not response: - response = [] + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + pdu = Pdu(**content) + res_pdus = yield self.handler.on_send_join_request(origin, pdu) + + defer.returnValue((200, { + "state": [p.get_dict() for p in res_pdus["state"]], + "auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]], + })) + @defer.inlineCallbacks + def on_event_auth(self, origin, context, event_id): + auth_pdus = yield self.handler.on_event_auth(event_id) defer.returnValue( - (200, self._transaction_from_pdus(response).get_dict()) + ( + 200, + { + "auth_chain": [a.get_dict() for a in auth_pdus], + } + ) ) @defer.inlineCallbacks - def on_query_request(self, query_type, args): - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue((404, "No handler for Query type '%s'" - % (query_type) - )) + def make_join(self, destination, context, user_id): + ret = yield self.transport_layer.make_join( + destination=destination, + context=context, + user_id=user_id, + ) + + pdu_dict = ret["event"] + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) @defer.inlineCallbacks + def send_join(self, destination, pdu): + _, content = yield self.transport_layer.send_join( + destination, + pdu.room_id, + pdu.event_id, + pdu.get_dict(), + ) + + logger.debug("Got content: %s", content) + state = [Pdu(outlier=True, **p) for p in content.get("state", [])] + for pdu in state: + yield self._handle_new_pdu(destination, pdu) + + auth_chain = [ + Pdu(outlier=True, **p) for p in content.get("auth_chain", []) + ] + for pdu in auth_chain: + yield self._handle_new_pdu(destination, pdu) + + defer.returnValue(state) + + @defer.inlineCallbacks + def send_invite(self, destination, context, event_id, pdu): + code, content = yield self.transport_layer.send_invite( + destination=destination, + context=context, + event_id=event_id, + content=pdu.get_dict(), + ) + + pdu_dict = content["event"] + + logger.debug("Got response to send_invite: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) + @log_function - def _get_persisted_pdu(self, pdu_id, pdu_origin): + def _get_persisted_pdu(self, origin, event_id): """ Get a PDU from the database with given origin and id. Returns: Deferred: Results in a `Pdu`. """ - pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin) - - defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple)) + return self.handler.get_persisted_pdu(origin, event_id) def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for transmission. """ pdus = [p.get_dict() for p in pdu_list] + time_now = self._clock.time_msec() for p in pdus: - if "age_ts" in pdus: - p["age"] = int(self.clock.time_msec()) - p["age_ts"] - + if "age_ts" in p: + age = time_now - p["age_ts"] + p.setdefault("unsigned", {})["age"] = int(age) + del p["age_ts"] return Transaction( origin=self.server_name, pdus=pdus, - origin_server_ts=int(self._clock.time_msec()), + origin_server_ts=int(time_now), destination=None, ) @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, pdu, backfilled=False): + def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) + existing = yield self._get_persisted_pdu(origin, pdu.event_id) if existing and (not existing.outlier or pdu.outlier): - logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin) + logger.debug("Already seen pdu %s", pdu.event_id) defer.returnValue({}) return + state = None + # Get missing pdus if necessary. - is_new = yield self.pdu_actions.is_new(pdu) - if is_new and not pdu.outlier: + if not pdu.outlier: # We only backfill backwards to the min depth. - min_depth = yield self.store.get_min_depth_for_context(pdu.context) + min_depth = yield self.handler.get_min_depth_for_context( + pdu.room_id + ) if min_depth and pdu.depth > min_depth: - for pdu_id, origin in pdu.prev_pdus: - exists = yield self._get_persisted_pdu(pdu_id, origin) + for event_id, hashes in pdu.prev_events: + exists = yield self._get_persisted_pdu(origin, event_id) if not exists: - logger.debug("Requesting pdu %s %s", pdu_id, origin) + logger.debug("Requesting pdu %s", event_id) try: yield self.get_pdu( pdu.origin, - pdu_id=pdu_id, - pdu_origin=origin + event_id=event_id, ) - logger.debug("Processed pdu %s %s", pdu_id, origin) + logger.debug("Processed pdu %s", event_id) except: # TODO(erikj): Do some more intelligent retries. logger.exception("Failed to get PDU") - - # Persist the Pdu, but don't mark it as processed yet. - yield self.store.persist_event(pdu=pdu) + else: + # We need to get the state at this event, since we have reached + # a backward extremity edge. + state = yield self.get_state_for_context( + origin, pdu.room_id, pdu.event_id, + ) if not backfilled: - ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) + ret = yield self.handler.on_receive_pdu( + pdu, + backfilled=backfilled, + state=state, + ) else: ret = None - yield self.pdu_actions.mark_as_processed(pdu) + # yield self.pdu_actions.mark_as_processed(pdu) defer.returnValue(ret) @@ -476,14 +566,6 @@ class ReplicationLayer(object): return "<ReplicationLayer(%s)>" % self.server_name -class ReplicationHandler(object): - """This defines the methods that the :py:class:`.ReplicationLayer` will - use to communicate with the rest of the home server. - """ - def on_receive_pdu(self, pdu): - raise NotImplementedError("on_receive_pdu") - - class _TransactionQueue(object): """This class makes sure we only have one transaction in flight at a time for a given destination. @@ -509,6 +591,9 @@ class _TransactionQueue(object): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = {} + # destination -> list of tuple(failure, deferred) + self.pending_failures_by_dest = {} + # HACK to get unique tx id self._next_txn_id = int(self._clock.time_msec()) @@ -562,6 +647,18 @@ class _TransactionQueue(object): return deferred @defer.inlineCallbacks + def enqueue_failure(self, failure, destination): + deferred = defer.Deferred() + + self.pending_failures_by_dest.setdefault( + destination, [] + ).append( + (failure, deferred) + ) + + yield deferred + + @defer.inlineCallbacks @log_function def _attempt_new_transaction(self, destination): if destination in self.pending_transactions: @@ -570,8 +667,9 @@ class _TransactionQueue(object): # list of (pending_pdu, deferred, order) pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - if not pending_pdus and not pending_edus: + if not pending_pdus and not pending_edus and not pending_failures: return logger.debug("TX [%s] Attempting new transaction", destination) @@ -581,7 +679,11 @@ class _TransactionQueue(object): pdus = [x[0] for x in pending_pdus] edus = [x[0] for x in pending_edus] - deferreds = [x[1] for x in pending_pdus + pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] try: self.pending_transactions[destination] = 1 @@ -589,12 +691,13 @@ class _TransactionQueue(object): logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=self._clock.time_msec(), + origin_server_ts=int(self._clock.time_msec()), transaction_id=str(self._next_txn_id), origin=self.server_name, destination=destination, pdus=pdus, edus=edus, + pdu_failures=failures, ) self._next_txn_id += 1 @@ -614,7 +717,9 @@ class _TransactionQueue(object): if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: - p["age"] = now - int(p["age_ts"]) + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] return data code, response = yield self.transport_layer.send_transaction( diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index e7517cac4d..95c40c6c1b 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -72,7 +72,7 @@ class TransportLayer(object): self.received_handler = None @log_function - def get_context_state(self, destination, context): + def get_context_state(self, destination, context, event_id=None): """ Requests all state for a given context (i.e. room) from the given server. @@ -89,54 +89,62 @@ class TransportLayer(object): subpath = "/state/%s/" % context - return self._do_request_for_transaction(destination, subpath) + args = {} + if event_id: + args["event_id"] = event_id + + return self._do_request_for_transaction( + destination, subpath, args=args + ) @log_function - def get_pdu(self, destination, pdu_origin, pdu_id): + def get_event(self, destination, event_id): """ Requests the pdu with give id and origin from the given server. Args: destination (str): The host name of the remote home server we want to get the state from. - pdu_origin (str): The home server which created the PDU. - pdu_id (str): The id of the PDU being requested. + event_id (str): The id of the event being requested. Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s", - destination, pdu_origin, pdu_id) + logger.debug("get_pdu dest=%s, event_id=%s", + destination, event_id) - subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id) + subpath = "/event/%s/" % (event_id, ) return self._do_request_for_transaction(destination, subpath) @log_function - def backfill(self, dest, context, pdu_tuples, limit): + def backfill(self, dest, context, event_tuples, limit): """ Requests `limit` previous PDUs in a given context before list of PDUs. Args: dest (str) context (str) - pdu_tuples (list) + event_tuples (list) limt (int) Returns: Deferred: Results in a dict received from the remote homeserver. """ logger.debug( - "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s", - dest, context, repr(pdu_tuples), str(limit) + "backfill dest=%s, context=%s, event_tuples=%s, limit=%s", + dest, context, repr(event_tuples), str(limit) ) - if not pdu_tuples: + if not event_tuples: + # TODO: raise? return - subpath = "/backfill/%s/" % context + subpath = "/backfill/%s/" % (context,) - args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} - args["limit"] = limit + args = { + "v": event_tuples, + "limit": limit, + } return self._do_request_for_transaction( dest, @@ -198,6 +206,72 @@ class TransportLayer(object): defer.returnValue(response) @defer.inlineCallbacks + @log_function + def make_join(self, destination, context, user_id, retry_on_dns_fail=True): + path = PREFIX + "/make_join/%s/%s" % (context, user_id,) + + response = yield self.client.get_json( + destination=destination, + path=path, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_join(self, destination, context, event_id, content): + path = PREFIX + "/send_join/%s/%s" % ( + context, + event_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, context, event_id, content): + path = PREFIX + "/invite/%s/%s" % ( + context, + event_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, context, event_id): + path = PREFIX + "/event_auth/%s/%s" % ( + context, + event_id, + ) + + response = yield self.client.get_json( + destination=destination, + path=path, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks def _authenticate_request(self, request): json_request = { "method": request.method, @@ -210,7 +284,7 @@ class TransportLayer(object): origin = None if request.method == "PUT": - #TODO: Handle other method types? other content types? + # TODO: Handle other method types? other content types? try: content_bytes = request.content.read() content = json.loads(content_bytes) @@ -222,11 +296,13 @@ class TransportLayer(object): try: params = auth.split(" ")[1].split(",") param_dict = dict(kv.split("=") for kv in params) + def strip_quotes(value): if value.startswith("\""): return value[1:-1] else: return value + origin = strip_quotes(param_dict["origin"]) key = strip_quotes(param_dict["key"]) sig = strip_quotes(param_dict["sig"]) @@ -247,7 +323,7 @@ class TransportLayer(object): if auth.startswith("X-Matrix"): (origin, key, sig) = parse_auth_header(auth) json_request["origin"] = origin - json_request["signatures"].setdefault(origin,{})[key] = sig + json_request["signatures"].setdefault(origin, {})[key] = sig if not json_request["signatures"]: raise SynapseError( @@ -313,10 +389,10 @@ class TransportLayer(object): # data_id pair. self.server.register_path( "GET", - re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"), + re.compile("^" + PREFIX + "/event/([^/]*)/$"), self._with_authentication( - lambda origin, content, query, pdu_origin, pdu_id: - handler.on_pdu_request(pdu_origin, pdu_id) + lambda origin, content, query, event_id: + handler.on_pdu_request(origin, event_id) ) ) @@ -326,7 +402,11 @@ class TransportLayer(object): re.compile("^" + PREFIX + "/state/([^/]*)/$"), self._with_authentication( lambda origin, content, query, context: - handler.on_context_state_request(context) + handler.on_context_state_request( + origin, + context, + query.get("event_id", [None])[0], + ) ) ) @@ -336,28 +416,63 @@ class TransportLayer(object): self._with_authentication( lambda origin, content, query, context: self._on_backfill_request( - context, query["v"], query["limit"] + origin, context, query["v"], query["limit"] ) ) ) + # This is when we receive a server-server Query self.server.register_path( "GET", - re.compile("^" + PREFIX + "/context/([^/]*)/$"), + re.compile("^" + PREFIX + "/query/([^/]*)$"), self._with_authentication( - lambda origin, content, query, context: - handler.on_context_pdus_request(context) + lambda origin, content, query, query_type: + handler.on_query_request( + query_type, {k: v[0] for k, v in query.items()} + ) ) ) - # This is when we receive a server-server Query self.server.register_path( "GET", - re.compile("^" + PREFIX + "/query/([^/]*)$"), + re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), self._with_authentication( - lambda origin, content, query, query_type: - handler.on_query_request( - query_type, {k: v[0] for k, v in query.items()} + lambda origin, content, query, context, user_id: + self._on_make_join_request( + origin, content, query, context, user_id + ) + ) + ) + + self.server.register_path( + "GET", + re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + handler.on_event_auth( + origin, context, event_id, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_send_join_request( + origin, content, query, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_invite_request( + origin, content, query, ) ) ) @@ -402,7 +517,8 @@ class TransportLayer(object): return try: - code, response = yield self.received_handler.on_incoming_transaction( + handler = self.received_handler + code, response = yield handler.on_incoming_transaction( transaction_data ) except: @@ -440,7 +556,7 @@ class TransportLayer(object): defer.returnValue(data) @log_function - def _on_backfill_request(self, context, v_list, limits): + def _on_backfill_request(self, origin, context, v_list, limits): if not limits: return defer.succeed( (400, {"error": "Did not include limit param"}) @@ -448,124 +564,34 @@ class TransportLayer(object): limit = int(limits[-1]) - versions = [v.split(",", 1) for v in v_list] + versions = v_list return self.request_handler.on_backfill_request( - context, versions, limit) - - -class TransportReceivedHandler(object): - """ Callbacks used when we receive a transaction - """ - def on_incoming_transaction(self, transaction): - """ Called on PUT /send/<transaction_id>, or on response to a request - that we sent (e.g. a backfill request) - - Args: - transaction (synapse.transaction.Transaction): The transaction that - was sent to us. - - Returns: - twisted.internet.defer.Deferred: A deferred that gets fired when - the transaction has finished being processed. - - The result should be a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - -class TransportRequestHandler(object): - """ Handlers used when someone want's data from us - """ - def on_pull_request(self, versions): - """ Called on GET /pull/?v=... - - This is hit when a remote home server wants to get all data - after a given transaction. Mainly used when a home server comes back - online and wants to get everything it has missed. - - Args: - versions (list): A list of transaction_ids that should be used to - determine what PDUs the remote side have not yet seen. - - Returns: - Deferred: Resultsin a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_pdu_request(self, pdu_origin, pdu_id): - """ Called on GET /pdu/<pdu_origin>/<pdu_id>/ - - Someone wants a particular PDU. This PDU may or may not have originated - from us. - - Args: - pdu_origin (str) - pdu_id (str) - - Returns: - Deferred: Resultsin a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_context_state_request(self, context): - """ Called on GET /state/<context>/ - - Gets hit when someone wants all the *current* state for a given - contexts. - - Args: - context (str): The name of the context that we're interested in. - - Returns: - twisted.internet.defer.Deferred: A deferred that gets fired when - the transaction has finished being processed. - - The result should be a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_backfill_request(self, context, versions, limit): - """ Called on GET /backfill/<context>/?v=...&limit=... + origin, context, versions, limit + ) - Gets hit when we want to backfill backwards on a given context from - the given point. + @defer.inlineCallbacks + @log_function + def _on_make_join_request(self, origin, content, query, context, user_id): + content = yield self.request_handler.on_make_join_request( + context, user_id, + ) + defer.returnValue((200, content)) - Args: - context (str): The context to backfill - versions (list): A list of 2-tuples representing where to backfill - from, in the form `(pdu_id, origin)` - limit (int): How many pdus to return. + @defer.inlineCallbacks + @log_function + def _on_send_join_request(self, origin, content, query): + content = yield self.request_handler.on_send_join_request( + origin, content, + ) - Returns: - Deferred: Results in a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. + defer.returnValue((200, content)) - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass + @defer.inlineCallbacks + @log_function + def _on_invite_request(self, origin, content, query): + content = yield self.request_handler.on_invite_request( + origin, content, + ) - def on_query_request(self): - """ Called on a GET /query/<query_type> request. """ + defer.returnValue((200, content)) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b2fb964180..f4e7b62bd9 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -20,8 +20,6 @@ server protocol. from synapse.util.jsonobject import JsonEncodedObject import logging -import json -import copy logger = logging.getLogger(__name__) @@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject): A Pdu can be classified as "state". For a given context, we can efficiently retrieve all state pdu's that haven't been clobbered. Clobbering is done - via a unique constraint on the tuple (context, pdu_type, state_key). A pdu + via a unique constraint on the tuple (context, type, state_key). A pdu is a state pdu if `is_state` is True. Example pdu:: { - "pdu_id": "78c", + "event_id": "$78c:example.com", "origin_server_ts": 1404835423000, "origin": "bar", "prev_ids": [ @@ -52,24 +50,21 @@ class Pdu(JsonEncodedObject): """ valid_keys = [ - "pdu_id", - "context", + "event_id", + "room_id", "origin", "origin_server_ts", - "pdu_type", + "type", "destinations", - "transaction_id", - "prev_pdus", + "prev_events", "depth", "content", - "outlier", - "is_state", # Below this are keys valid only for State Pdus. - "state_key", - "power_level", - "prev_state_id", - "prev_state_origin", - "required_power_level", + "hashes", "user_id", + "auth_events", + "signatures", # Below this are keys valid only for State Pdus. + "state_key", + "prev_state", ] internal_keys = [ @@ -79,61 +74,28 @@ class Pdu(JsonEncodedObject): ] required_keys = [ - "pdu_id", - "context", + "event_id", + "room_id", "origin", "origin_server_ts", - "pdu_type", + "type", "content", ] # TODO: We need to make this properly load content rather than # just leaving it as a dict. (OR DO WE?!) - def __init__(self, destinations=[], is_state=False, prev_pdus=[], - outlier=False, **kwargs): - if is_state: - for required_key in ["state_key"]: - if required_key not in kwargs: - raise RuntimeError("Key %s is required" % required_key) - + def __init__(self, destinations=[], prev_events=[], + outlier=False, hashes={}, signatures={}, **kwargs): super(Pdu, self).__init__( destinations=destinations, - is_state=is_state, - prev_pdus=prev_pdus, + prev_events=prev_events, outlier=outlier, + hashes=hashes, + signatures=signatures, **kwargs ) - @classmethod - def from_pdu_tuple(cls, pdu_tuple): - """ Converts a PduTuple to a Pdu - - Args: - pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to - convert - - Returns: - Pdu - """ - if pdu_tuple: - d = copy.copy(pdu_tuple.pdu_entry._asdict()) - d["origin_server_ts"] = d.pop("ts") - - d["content"] = json.loads(d["content_json"]) - del d["content_json"] - - args = {f: d[f] for f in cls.valid_keys if f in d} - if "unrecognized_keys" in d and d["unrecognized_keys"]: - args.update(json.loads(d["unrecognized_keys"])) - - return Pdu( - prev_pdus=pdu_tuple.prev_pdu_list, - **args - ) - else: - return None - def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) @@ -193,6 +155,7 @@ class Transaction(JsonEncodedObject): "edus", "transaction_id", "destination", + "pdu_failures", ] internal_keys = [ @@ -229,7 +192,9 @@ class Transaction(JsonEncodedObject): transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") + raise KeyError( + "Require 'origin_server_ts' to construct a Transaction" + ) if "transaction_id" not in kwargs: raise KeyError( "Require 'transaction_id' to construct a Transaction" @@ -241,6 +206,3 @@ class Transaction(JsonEncodedObject): kwargs["pdus"] = [p.get_dict() for p in pdus] return Transaction(**kwargs) - - - |