diff options
author | David Baker <dave@matrix.org> | 2016-08-11 14:09:13 +0100 |
---|---|---|
committer | David Baker <dave@matrix.org> | 2016-08-11 14:09:13 +0100 |
commit | b4ecf0b886c67437901e0af457c5f801ebde9a72 (patch) | |
tree | ef66b0684edcfeb4ad68d20375641f4654393f44 /synapse/handlers/federation.py | |
parent | Include the ts the notif was received at (diff) | |
parent | Merge pull request #1003 from matrix-org/erikj/redaction_prev_content (diff) | |
download | synapse-b4ecf0b886c67437901e0af457c5f801ebde9a72.tar.xz |
Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r-- | synapse/handlers/federation.py | 189 |
1 files changed, 122 insertions, 67 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 648a505e65..ff6bb475b5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -66,10 +66,6 @@ class FederationHandler(BaseHandler): self.hs = hs - self.distributor.observe("user_joined_room", self.user_joined_room) - - self.waiting_for_join_list = {} - self.store = hs.get_datastore() self.replication_layer = hs.get_replication_layer() self.state_handler = hs.get_state_handler() @@ -128,7 +124,7 @@ class FederationHandler(BaseHandler): try: event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) except AuthError as e: raise FederationError( @@ -253,7 +249,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -339,29 +335,58 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - seen_events = yield self.store.have_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - all_events = events + state_events.values() + auth_events.values() required_auth = set( - a_id for event in all_events for a_id, _ in event.auth_events + a_id + for event in events + state_events.values() + auth_events.values() + for a_id, _ in event.auth_events ) - + auth_events.update({ + e_id: event_map[e_id] for e_id in required_auth if e_id in event_map + }) missing_auth = required_auth - set(auth_events) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, + failed_to_fetch = set() + + # Try and fetch any missing auth events from both DB and remote servers. + # We repeatedly do this until we stop finding new auth events. + while missing_auth - failed_to_fetch: + logger.info("Missing auth for backfill: %r", missing_auth) + ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + auth_events.update(ret_events) + + required_auth.update( + a_id for event in ret_events.values() for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + if missing_auth - failed_to_fetch: + logger.info( + "Fetching missing auth for backfill: %r", + missing_auth - failed_to_fetch ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + + results = yield defer.gatherResults( + [ + self.replication_layer.get_pdu( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True + ).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results}) + required_auth.update( + a_id for event in results for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + failed_to_fetch = missing_auth - set(auth_events) + + seen_events = yield self.store.have_events( + set(auth_events.keys()) | set(state_events.keys()) + ) ev_infos = [] for a in auth_events.values(): @@ -374,6 +399,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in a.auth_events + if a_id in auth_events } }) @@ -385,6 +411,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in event_map[e_id].auth_events + if a_id in auth_events } }) @@ -403,7 +430,7 @@ class FederationHandler(BaseHandler): # previous to work out the state. # TODO: We can probably do something more clever here. yield self._handle_new_event( - dest, event + dest, event, backfilled=True, ) defer.returnValue(events) @@ -639,7 +666,7 @@ class FederationHandler(BaseHandler): pass event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) with PreserveLoggingContext(): @@ -690,7 +717,9 @@ class FederationHandler(BaseHandler): logger.warn("Failed to create join %r because %s", event, e) raise e - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) defer.returnValue(event) @@ -920,7 +949,9 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e @@ -989,14 +1020,9 @@ class FederationHandler(BaseHandler): defer.returnValue(None) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): + def get_state_for_pdu(self, room_id, event_id): yield run_on_reactor() - if do_auth: - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -1020,13 +1046,16 @@ class FederationHandler(BaseHandler): res = results.values() for event in res: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) defer.returnValue(res) else: @@ -1064,16 +1093,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( @@ -1083,6 +1113,12 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( + origin, event.room_id, [event] + ) + + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) @@ -1091,15 +1127,6 @@ class FederationHandler(BaseHandler): def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) - @log_function - def user_joined_room(self, user, room_id): - waiters = self.waiting_for_join_list.get( - (user.to_string(), room_id), - [] - ) - while waiters: - waiters.pop().callback(None) - @defer.inlineCallbacks @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None, @@ -1122,11 +1149,12 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( - event_stream_id, max_stream_id - ) + if not backfilled: + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) defer.returnValue((context, event_stream_id, max_stream_id)) @@ -1158,11 +1186,19 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks - def _persist_auth_tree(self, auth_events, state, event): + def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Will attempt to fetch missing auth events. + + Args: + origin (str): Where the events came from + auth_events (list) + state (list) + event (Event) + Returns: 2-tuple of (event_stream_id, max_stream_id) from the persist_event call for `event` @@ -1175,7 +1211,7 @@ class FederationHandler(BaseHandler): event_map = { e.event_id: e - for e in auth_events + for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1184,10 +1220,29 @@ class FederationHandler(BaseHandler): create_event = e break + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id, _ in e.auth_events: + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = yield self.replication_layer.get_pdu( + [origin], + e_id, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] for e_id, _ in e.auth_events + if e_id in event_map } if create_event: auth_for_e[(EventTypes.Create, "")] = create_event @@ -1421,7 +1476,7 @@ class FederationHandler(BaseHandler): local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( |