diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4c3cd9d12e..e2f4387f60 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
- prev_state = self.deduplicate_state_event(event, context)
+ prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
@@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
+ @defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
@@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_event = context.current_state.get((event.type, event.state_key))
+ prev_event_id = context.current_state_ids.get((event.type, event.state_key))
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ if not prev_event:
+ return
+
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
- return prev_event
- return None
+ defer.returnValue(prev_event)
+ return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
@@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
logger.debug(
"Created event %s with current state: %s",
- event.event_id, context.current_state,
+ event.event_id, context.current_state_ids,
)
defer.returnValue(
@@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
self.ratelimit(requester)
try:
- self.auth.check(event, auth_events=context.current_state)
+ yield self.auth.check_from_context(event, context)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
- yield self.maybe_kick_guest_users(event, context.current_state.values())
+ yield self.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
@@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
e.sender == event.sender
)
+ state_to_include_ids = [
+ e_id
+ for k, e_id in context.current_state_ids.items()
+ if k[0] in self.hs.config.room_invite_state_types
+ or k[0] == EventTypes.Member and k[1] == event.sender
+ ]
+
+ state_to_include = yield self.store.get_events(state_to_include_ids)
+
event.unsigned["invite_room_state"] = [
{
"type": e.type,
@@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
- for k, e in context.current_state.items()
- if e.type in self.hs.config.room_invite_state_types
- or is_inviter_member_event(e)
+ for e in state_to_include.values()
]
invitee = UserID.from_string(event.state_key)
@@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Redaction:
- if self.auth.check_redaction(event, auth_events=context.current_state):
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, context.current_state_ids, for_verification=True,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.values()
+ }
+ if self.auth.check_redaction(event, auth_events=auth_events):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
- if event.type == EventTypes.Create and context.current_state:
+ if event.type == EventTypes.Create and context.current_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
@@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id
)
- destinations = set()
- for k, s in context.current_state.items():
- try:
- if k[0] == EventTypes.Member:
- if s.content["membership"] == Membership.JOIN:
- destinations.add(get_domain_from_id(s.state_key))
- except SynapseError:
- logger.warn(
- "Failed to get destination from event %s", s.event_id
- )
+ destinations = yield self.get_joined_hosts_for_room_from_state(context)
@defer.inlineCallbacks
def _notify():
@@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations,
)
+
+ def get_joined_hosts_for_room_from_state(self, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_hosts_for_room_from_state(
+ state_group, context.current_state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=1, cache_context=True)
+ def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
+ cache_context):
+
+ # Don't bother getting state for people on the same HS
+ current_state = yield self.store.get_events([
+ e_id for key, e_id in current_state_ids.items()
+ if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
+ ])
+
+ destinations = set()
+ for e in current_state.itervalues():
+ try:
+ if e.type == EventTypes.Member:
+ if e.content["membership"] == Membership.JOIN:
+ destinations.add(get_domain_from_id(e.state_key))
+ except SynapseError:
+ logger.warn(
+ "Failed to get destination from event %s", e.event_id
+ )
+
+ defer.returnValue(destinations)
|