diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4344a2bd52..a7ea8fb98f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
)
+from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -217,17 +218,28 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
- prev_state = context.current_state.get((event.type, event.state_key))
- if not prev_state or prev_state.membership != Membership.JOIN:
- # Only fire user_joined_room if the user has acutally
- # joined the room. Don't bother if the user is just
- # changing their profile info.
+ # Only fire user_joined_room if the user has acutally
+ # joined the room. Don't bother if the user is just
+ # changing their profile info.
+ newly_joined = True
+ prev_state_id = context.current_state_ids.get(
+ (event.type, event.state_key)
+ )
+ if prev_state_id:
+ prev_state = yield self.store.get_event(
+ prev_state_id, allow_none=True,
+ )
+ if prev_state and prev_state.membership == Membership.JOIN:
+ newly_joined = False
+
+ if newly_joined:
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
+ @measure_func("_filter_events_for_server")
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
- event_to_state = yield self.store.get_state_for_events(
+ event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
@@ -235,6 +247,30 @@ class FederationHandler(BaseHandler):
)
)
+ # We only want to pull out member events that correspond to the
+ # server's domain.
+
+ def check_match(id):
+ try:
+ return server_name == get_domain_from_id(id)
+ except:
+ return False
+
+ event_map = yield self.store.get_events([
+ e_id for key_to_eid in event_to_state_ids.values()
+ for key, e_id in key_to_eid
+ if key[0] != EventTypes.Member or check_match(key[1])
+ ])
+
+ event_to_state = {
+ e_id: {
+ key: event_map[inner_e_id]
+ for key, inner_e_id in key_to_eid.items()
+ if inner_e_id in event_map
+ }
+ for e_id, key_to_eid in event_to_state_ids.items()
+ }
+
def redact_disallowed(event, state):
if not state:
return event
@@ -562,6 +598,18 @@ class FederationHandler(BaseHandler):
]))
states = dict(zip(event_ids, [s[1] for s in states]))
+ state_map = yield self.store.get_events(
+ [e_id for ids in states.values() for e_id in ids],
+ get_prev_content=False
+ )
+ states = {
+ key: {
+ k: state_map[e_id]
+ for k, e_id in state_dict.items()
+ if e_id in state_map
+ } for key, state_dict in states.items()
+ }
+
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
@@ -724,7 +772,7 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+ yield self.auth.check_from_context(event, context, do_sig_check=False)
defer.returnValue(event)
@@ -772,18 +820,11 @@ class FederationHandler(BaseHandler):
new_pdu = event
- destinations = set()
-
- for k, s in context.current_state.items():
- try:
- if k[0] == EventTypes.Member:
- if s.content["membership"] == Membership.JOIN:
- destinations.add(get_domain_from_id(s.state_key))
- except:
- logger.warn(
- "Failed to get destination from event %s", s.event_id
- )
-
+ message_handler = self.hs.get_handlers().message_handler
+ destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+ context
+ )
+ destinations = set(destinations)
destinations.discard(origin)
logger.debug(
@@ -794,13 +835,15 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations)
- state_ids = [e.event_id for e in context.current_state.values()]
+ state_ids = context.current_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
+ state = yield self.store.get_events(context.current_state_ids.values())
+
defer.returnValue({
- "state": context.current_state.values(),
+ "state": state.values(),
"auth_chain": auth_chain,
})
@@ -956,7 +999,7 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+ yield self.auth.check_from_context(event, context, do_sig_check=False)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -1000,18 +1043,11 @@ class FederationHandler(BaseHandler):
new_pdu = event
- destinations = set()
-
- for k, s in context.current_state.items():
- try:
- if k[0] == EventTypes.Member:
- if s.content["membership"] == Membership.LEAVE:
- destinations.add(get_domain_from_id(s.state_key))
- except:
- logger.warn(
- "Failed to get destination from event %s", s.event_id
- )
-
+ message_handler = self.hs.get_handlers().message_handler
+ destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+ context
+ )
+ destinations = set(destinations)
destinations.discard(origin)
logger.debug(
@@ -1296,7 +1332,13 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
- auth_events = context.current_state
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, context.current_state_ids, for_verification=True,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.values()
+ }
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -1322,8 +1364,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess:
- full_context = yield self.store.get_current_state(room_id=event.room_id)
- yield self.maybe_kick_guest_users(event, full_context)
+ yield self.maybe_kick_guest_users(event)
defer.returnValue(context)
@@ -1494,7 +1535,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- context.current_state.update(auth_events)
+ context.current_state_ids.update({
+ k: a.event_id for k, a in auth_events.items()
+ })
context.state_group = None
if different_auth and not event.internal_metadata.is_outlier():
@@ -1516,8 +1559,8 @@ class FederationHandler(BaseHandler):
if do_resolution:
# 1. Get what we think is the auth chain.
- auth_ids = self.auth.compute_auth_events(
- event, context.current_state
+ auth_ids = yield self.auth.compute_auth_events(
+ event, context.current_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@@ -1573,7 +1616,9 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
- context.current_state.update(auth_events)
+ context.current_state_ids.update({
+ k: a.event_id for k, a in auth_events.items()
+ })
context.state_group = None
try:
@@ -1760,12 +1805,12 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, context.current_state)
+ yield self.auth.check_from_context(event, context)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, auth_events=context.current_state)
+ yield self._check_signature(event, context)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context)
else:
@@ -1791,11 +1836,11 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, auth_events=context.current_state)
+ self.auth.check_from_context(event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, auth_events=context.current_state)
+ yield self._check_signature(event, context)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
@@ -1809,7 +1854,12 @@ class FederationHandler(BaseHandler):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
- original_invite = context.current_state.get(key)
+ original_invite = None
+ original_invite_id = context.current_state_ids.get(key)
+ if original_invite_id:
+ original_invite = yield self.store.get_event(
+ original_invite_id, allow_none=True
+ )
if not original_invite:
logger.info(
"Could not find invite event for third_party_invite - "
@@ -1826,13 +1876,13 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context))
@defer.inlineCallbacks
- def _check_signature(self, event, auth_events):
+ def _check_signature(self, event, context):
"""
Checks that the signature in the event is consistent with its invite.
Args:
event (Event): The m.room.member event to check
- auth_events (dict<(event type, state_key), event>):
+ context (EventContext):
Raises:
AuthError: if signature didn't match any keys, or key has been
@@ -1843,10 +1893,14 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- invite_event = auth_events.get(
+ invite_event_id = context.current_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)
+ invite_event = None
+ if invite_event_id:
+ invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+
if not invite_event:
raise AuthError(403, "Could not find invite")
|