diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0db26fcfd7..f26e585623 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -52,7 +52,7 @@ class Auth(object):
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
- # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+ # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,6 +63,17 @@ class Auth(object):
"user_id = ",
])
+ @defer.inlineCallbacks
+ def check_from_context(self, event, context, do_sig_check=True):
+ auth_events_ids = yield self.compute_auth_events(
+ event, context.current_state_ids, for_verification=True,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.values()
+ }
+ self.check(event, auth_events=auth_events, do_sig_check=False)
+
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
@@ -267,21 +278,15 @@ class Auth(object):
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
- curr_state = yield self.state.get_current_state(room_id)
-
- for event in curr_state.values():
- if event.type == EventTypes.Member:
- try:
- if get_domain_from_id(event.state_key) != host:
- continue
- except:
- logger.warn("state_key not user_id: %s", event.state_key)
- continue
+ with Measure(self.clock, "check_host_in_room"):
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- if event.content["membership"] == Membership.JOIN:
- defer.returnValue(True)
+ group, curr_state_ids = yield self.state.resolve_state_groups(
+ room_id, latest_event_ids
+ )
- defer.returnValue(False)
+ ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
+ defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, )
@@ -847,7 +852,7 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
- auth_ids = self.compute_auth_events(builder, context.current_state)
+ auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
@@ -855,30 +860,32 @@ class Auth(object):
builder.auth_events = auth_events_entries
- def compute_auth_events(self, event, current_state):
+ @defer.inlineCallbacks
+ def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
- return []
+ defer.returnValue([])
auth_ids = []
key = (EventTypes.PowerLevels, "", )
- power_level_event = current_state.get(key)
+ power_level_event_id = current_state_ids.get(key)
- if power_level_event:
- auth_ids.append(power_level_event.event_id)
+ if power_level_event_id:
+ auth_ids.append(power_level_event_id)
key = (EventTypes.JoinRules, "", )
- join_rule_event = current_state.get(key)
+ join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, )
- member_event = current_state.get(key)
+ member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", )
- create_event = current_state.get(key)
- if create_event:
- auth_ids.append(create_event.event_id)
+ create_event_id = current_state_ids.get(key)
+ if create_event_id:
+ auth_ids.append(create_event_id)
- if join_rule_event:
+ if join_rule_event_id:
+ join_rule_event = yield self.store.get_event(join_rule_event_id)
join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
else:
@@ -887,15 +894,21 @@ class Auth(object):
if event.type == EventTypes.Member:
e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
- if join_rule_event:
- auth_ids.append(join_rule_event.event_id)
+ if join_rule_event_id:
+ auth_ids.append(join_rule_event_id)
if e_type == Membership.JOIN:
- if member_event and not is_public:
- auth_ids.append(member_event.event_id)
+ if member_event_id and not is_public:
+ auth_ids.append(member_event_id)
else:
- if member_event:
- auth_ids.append(member_event.event_id)
+ if member_event_id:
+ auth_ids.append(member_event_id)
+
+ if for_verification:
+ key = (EventTypes.Member, event.state_key, )
+ existing_event_id = current_state_ids.get(key)
+ if existing_event_id:
+ auth_ids.append(existing_event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
@@ -903,14 +916,15 @@ class Auth(object):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
- third_party_invite = current_state.get(key)
- if third_party_invite:
- auth_ids.append(third_party_invite.event_id)
- elif member_event:
+ third_party_invite_id = current_state_ids.get(key)
+ if third_party_invite_id:
+ auth_ids.append(third_party_invite_id)
+ elif member_event_id:
+ member_event = yield self.store.get_event(member_event_id)
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
- return auth_ids
+ defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 13154b1723..bcb8f33a58 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -99,7 +99,7 @@ class EventBase(object):
return d
- def get(self, key, default):
+ def get(self, key, default=None):
return self._event_dict.get(key, default)
def get_internal_metadata_dict(self):
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 8a475417a6..c75afd02d8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,9 +15,8 @@
class EventContext(object):
-
- def __init__(self, current_state=None):
- self.current_state = current_state
+ def __init__(self, current_state_ids=None):
+ self.current_state_ids = current_state_ids
self.state_group = None
self.rejected = False
self.push_actions = []
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 11081a0cd5..e58735294e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -65,33 +65,21 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
- def is_host_in_room(self, current_state):
- room_members = [
- (state_key, event.membership)
- for ((event_type, state_key), event) in current_state.items()
- if event_type == EventTypes.Member
- ]
- if len(room_members) == 0:
- # Have we just created the room, and is this about to be the very
- # first member event?
- create_event = current_state.get(("m.room.create", ""))
- if create_event:
- return True
- for (state_key, membership) in room_members:
- if (
- self.hs.is_mine_id(state_key)
- and membership == Membership.JOIN
- ):
- return True
- return False
-
@defer.inlineCallbacks
- def maybe_kick_guest_users(self, event, current_state):
+ def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
+ if context:
+ current_state = yield self.store.get_events(
+ context.current_state_ids.values()
+ )
+ current_state = current_state.values()
+ else:
+ current_state = yield self.store.get_current_state(event.room_id)
+ logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
@defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4344a2bd52..a7ea8fb98f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
)
+from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -217,17 +218,28 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
- prev_state = context.current_state.get((event.type, event.state_key))
- if not prev_state or prev_state.membership != Membership.JOIN:
- # Only fire user_joined_room if the user has acutally
- # joined the room. Don't bother if the user is just
- # changing their profile info.
+ # Only fire user_joined_room if the user has acutally
+ # joined the room. Don't bother if the user is just
+ # changing their profile info.
+ newly_joined = True
+ prev_state_id = context.current_state_ids.get(
+ (event.type, event.state_key)
+ )
+ if prev_state_id:
+ prev_state = yield self.store.get_event(
+ prev_state_id, allow_none=True,
+ )
+ if prev_state and prev_state.membership == Membership.JOIN:
+ newly_joined = False
+
+ if newly_joined:
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
+ @measure_func("_filter_events_for_server")
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
- event_to_state = yield self.store.get_state_for_events(
+ event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
@@ -235,6 +247,30 @@ class FederationHandler(BaseHandler):
)
)
+ # We only want to pull out member events that correspond to the
+ # server's domain.
+
+ def check_match(id):
+ try:
+ return server_name == get_domain_from_id(id)
+ except:
+ return False
+
+ event_map = yield self.store.get_events([
+ e_id for key_to_eid in event_to_state_ids.values()
+ for key, e_id in key_to_eid
+ if key[0] != EventTypes.Member or check_match(key[1])
+ ])
+
+ event_to_state = {
+ e_id: {
+ key: event_map[inner_e_id]
+ for key, inner_e_id in key_to_eid.items()
+ if inner_e_id in event_map
+ }
+ for e_id, key_to_eid in event_to_state_ids.items()
+ }
+
def redact_disallowed(event, state):
if not state:
return event
@@ -562,6 +598,18 @@ class FederationHandler(BaseHandler):
]))
states = dict(zip(event_ids, [s[1] for s in states]))
+ state_map = yield self.store.get_events(
+ [e_id for ids in states.values() for e_id in ids],
+ get_prev_content=False
+ )
+ states = {
+ key: {
+ k: state_map[e_id]
+ for k, e_id in state_dict.items()
+ if e_id in state_map
+ } for key, state_dict in states.items()
+ }
+
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
@@ -724,7 +772,7 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+ yield self.auth.check_from_context(event, context, do_sig_check=False)
defer.returnValue(event)
@@ -772,18 +820,11 @@ class FederationHandler(BaseHandler):
new_pdu = event
- destinations = set()
-
- for k, s in context.current_state.items():
- try:
- if k[0] == EventTypes.Member:
- if s.content["membership"] == Membership.JOIN:
- destinations.add(get_domain_from_id(s.state_key))
- except:
- logger.warn(
- "Failed to get destination from event %s", s.event_id
- )
-
+ message_handler = self.hs.get_handlers().message_handler
+ destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+ context
+ )
+ destinations = set(destinations)
destinations.discard(origin)
logger.debug(
@@ -794,13 +835,15 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations)
- state_ids = [e.event_id for e in context.current_state.values()]
+ state_ids = context.current_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
+ state = yield self.store.get_events(context.current_state_ids.values())
+
defer.returnValue({
- "state": context.current_state.values(),
+ "state": state.values(),
"auth_chain": auth_chain,
})
@@ -956,7 +999,7 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+ yield self.auth.check_from_context(event, context, do_sig_check=False)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -1000,18 +1043,11 @@ class FederationHandler(BaseHandler):
new_pdu = event
- destinations = set()
-
- for k, s in context.current_state.items():
- try:
- if k[0] == EventTypes.Member:
- if s.content["membership"] == Membership.LEAVE:
- destinations.add(get_domain_from_id(s.state_key))
- except:
- logger.warn(
- "Failed to get destination from event %s", s.event_id
- )
-
+ message_handler = self.hs.get_handlers().message_handler
+ destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+ context
+ )
+ destinations = set(destinations)
destinations.discard(origin)
logger.debug(
@@ -1296,7 +1332,13 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
- auth_events = context.current_state
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, context.current_state_ids, for_verification=True,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.values()
+ }
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -1322,8 +1364,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess:
- full_context = yield self.store.get_current_state(room_id=event.room_id)
- yield self.maybe_kick_guest_users(event, full_context)
+ yield self.maybe_kick_guest_users(event)
defer.returnValue(context)
@@ -1494,7 +1535,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- context.current_state.update(auth_events)
+ context.current_state_ids.update({
+ k: a.event_id for k, a in auth_events.items()
+ })
context.state_group = None
if different_auth and not event.internal_metadata.is_outlier():
@@ -1516,8 +1559,8 @@ class FederationHandler(BaseHandler):
if do_resolution:
# 1. Get what we think is the auth chain.
- auth_ids = self.auth.compute_auth_events(
- event, context.current_state
+ auth_ids = yield self.auth.compute_auth_events(
+ event, context.current_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@@ -1573,7 +1616,9 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
- context.current_state.update(auth_events)
+ context.current_state_ids.update({
+ k: a.event_id for k, a in auth_events.items()
+ })
context.state_group = None
try:
@@ -1760,12 +1805,12 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, context.current_state)
+ yield self.auth.check_from_context(event, context)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, auth_events=context.current_state)
+ yield self._check_signature(event, context)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context)
else:
@@ -1791,11 +1836,11 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, auth_events=context.current_state)
+ self.auth.check_from_context(event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, auth_events=context.current_state)
+ yield self._check_signature(event, context)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
@@ -1809,7 +1854,12 @@ class FederationHandler(BaseHandler):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
- original_invite = context.current_state.get(key)
+ original_invite = None
+ original_invite_id = context.current_state_ids.get(key)
+ if original_invite_id:
+ original_invite = yield self.store.get_event(
+ original_invite_id, allow_none=True
+ )
if not original_invite:
logger.info(
"Could not find invite event for third_party_invite - "
@@ -1826,13 +1876,13 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context))
@defer.inlineCallbacks
- def _check_signature(self, event, auth_events):
+ def _check_signature(self, event, context):
"""
Checks that the signature in the event is consistent with its invite.
Args:
event (Event): The m.room.member event to check
- auth_events (dict<(event type, state_key), event>):
+ context (EventContext):
Raises:
AuthError: if signature didn't match any keys, or key has been
@@ -1843,10 +1893,14 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- invite_event = auth_events.get(
+ invite_event_id = context.current_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)
+ invite_event = None
+ if invite_event_id:
+ invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+
if not invite_event:
raise AuthError(403, "Could not find invite")
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4c3cd9d12e..e2f4387f60 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
- prev_state = self.deduplicate_state_event(event, context)
+ prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
@@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
+ @defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
@@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_event = context.current_state.get((event.type, event.state_key))
+ prev_event_id = context.current_state_ids.get((event.type, event.state_key))
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ if not prev_event:
+ return
+
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
- return prev_event
- return None
+ defer.returnValue(prev_event)
+ return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
@@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
logger.debug(
"Created event %s with current state: %s",
- event.event_id, context.current_state,
+ event.event_id, context.current_state_ids,
)
defer.returnValue(
@@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
self.ratelimit(requester)
try:
- self.auth.check(event, auth_events=context.current_state)
+ yield self.auth.check_from_context(event, context)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
- yield self.maybe_kick_guest_users(event, context.current_state.values())
+ yield self.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
@@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
e.sender == event.sender
)
+ state_to_include_ids = [
+ e_id
+ for k, e_id in context.current_state_ids.items()
+ if k[0] in self.hs.config.room_invite_state_types
+ or k[0] == EventTypes.Member and k[1] == event.sender
+ ]
+
+ state_to_include = yield self.store.get_events(state_to_include_ids)
+
event.unsigned["invite_room_state"] = [
{
"type": e.type,
@@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
- for k, e in context.current_state.items()
- if e.type in self.hs.config.room_invite_state_types
- or is_inviter_member_event(e)
+ for e in state_to_include.values()
]
invitee = UserID.from_string(event.state_key)
@@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Redaction:
- if self.auth.check_redaction(event, auth_events=context.current_state):
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, context.current_state_ids, for_verification=True,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.values()
+ }
+ if self.auth.check_redaction(event, auth_events=auth_events):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
- if event.type == EventTypes.Create and context.current_state:
+ if event.type == EventTypes.Create and context.current_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
@@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id
)
- destinations = set()
- for k, s in context.current_state.items():
- try:
- if k[0] == EventTypes.Member:
- if s.content["membership"] == Membership.JOIN:
- destinations.add(get_domain_from_id(s.state_key))
- except SynapseError:
- logger.warn(
- "Failed to get destination from event %s", s.event_id
- )
+ destinations = yield self.get_joined_hosts_for_room_from_state(context)
@defer.inlineCallbacks
def _notify():
@@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations,
)
+
+ def get_joined_hosts_for_room_from_state(self, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_hosts_for_room_from_state(
+ state_group, context.current_state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=1, cache_context=True)
+ def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
+ cache_context):
+
+ # Don't bother getting state for people on the same HS
+ current_state = yield self.store.get_events([
+ e_id for key, e_id in current_state_ids.items()
+ if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
+ ])
+
+ destinations = set()
+ for e in current_state.itervalues():
+ try:
+ if e.type == EventTypes.Member:
+ if e.content["membership"] == Membership.JOIN:
+ destinations.add(get_domain_from_id(e.state_key))
+ except SynapseError:
+ logger.warn(
+ "Failed to get destination from event %s", e.event_id
+ )
+
+ defer.returnValue(destinations)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8b17632fdc..dd4b90ee24 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
- prev_member_event = context.current_state.get(
+ prev_member_event_id = context.current_state_ids.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
- if not prev_member_event or prev_member_event.membership != Membership.JOIN:
- # Only fire user_joined_room if the user has acutally joined the
- # room. Don't bother if the user is just changing their profile
- # info.
+ # Only fire user_joined_room if the user has acutally joined the
+ # room. Don't bother if the user is just changing their profile
+ # info.
+ newly_joined = True
+ if prev_member_event_id:
+ prev_member_event = yield self.store.get_event(prev_member_event_id)
+ newly_joined = prev_member_event.membership != Membership.JOIN
+ if newly_joined:
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
- if prev_member_event and prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target, room_id)
+ if prev_member_event_id:
+ prev_member_event = yield self.store.get_event(prev_member_event_id)
+ if prev_member_event.membership == Membership.JOIN:
+ user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
@@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- current_state = yield self.state_handler.get_current_state(
+ current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
)
- old_state = current_state.get((EventTypes.Member, target.to_string()))
- old_membership = old_state.content.get("membership") if old_state else None
- if action == "unban" and old_membership != "ban":
- raise SynapseError(
- 403,
- "Cannot unban user who was not banned (membership=%s)" % old_membership,
- errcode=Codes.BAD_STATE
- )
- if old_membership == "ban" and action != "unban":
- raise SynapseError(
- 403,
- "Cannot %s user who was banned" % (action,),
- errcode=Codes.BAD_STATE
- )
+ old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+ if old_state_id:
+ old_state = yield self.store.get_event(old_state_id, allow_none=True)
+ old_membership = old_state.content.get("membership") if old_state else None
+ if action == "unban" and old_membership != "ban":
+ raise SynapseError(
+ 403,
+ "Cannot unban user who was not banned"
+ " (membership=%s)" % old_membership,
+ errcode=Codes.BAD_STATE
+ )
+ if old_membership == "ban" and action != "unban":
+ raise SynapseError(
+ 403,
+ "Cannot %s user who was banned" % (action,),
+ errcode=Codes.BAD_STATE
+ )
- is_host_in_room = self.is_host_in_room(current_state)
+ is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
- if requester.is_guest and not self._can_guest_join(current_state):
+ if requester.is_guest and not self._can_guest_join(current_state_ids):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
@@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler):
requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler
- prev_event = message_handler.deduplicate_state_event(event, context)
+ prev_event = yield message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
- if requester.is_guest and not self._can_guest_join(context.current_state):
- # This should be an auth check, but guests are a local concept,
- # so don't really fit into the general auth process.
- raise AuthError(403, "Guest access not allowed")
+ if requester.is_guest:
+ guest_can_join = yield self._can_guest_join(context.current_state_ids)
+ if not guest_can_join:
+ # This should be an auth check, but guests are a local concept,
+ # so don't really fit into the general auth process.
+ raise AuthError(403, "Guest access not allowed")
yield message_handler.handle_new_client_event(
requester,
@@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
- prev_member_event = context.current_state.get(
- (EventTypes.Member, target_user.to_string()),
+ prev_member_event_id = context.current_state_ids.get(
+ (EventTypes.Member, event.state_key),
None
)
if event.membership == Membership.JOIN:
- if not prev_member_event or prev_member_event.membership != Membership.JOIN:
- # Only fire user_joined_room if the user has acutally joined the
- # room. Don't bother if the user is just changing their profile
- # info.
+ # Only fire user_joined_room if the user has acutally joined the
+ # room. Don't bother if the user is just changing their profile
+ # info.
+ newly_joined = True
+ if prev_member_event_id:
+ prev_member_event = yield self.store.get_event(prev_member_event_id)
+ newly_joined = prev_member_event.membership != Membership.JOIN
+ if newly_joined:
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
- if prev_member_event and prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target_user, room_id)
+ if prev_member_event_id:
+ prev_member_event = yield self.store.get_event(prev_member_event_id)
+ if prev_member_event.membership == Membership.JOIN:
+ user_left_room(self.distributor, target_user, room_id)
- def _can_guest_join(self, current_state):
+ @defer.inlineCallbacks
+ def _can_guest_join(self, current_state_ids):
"""
Returns whether a guest can join a room based on its current state.
"""
- guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
- return (
+ guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+ if not guest_access_id:
+ defer.returnValue(False)
+
+ guest_access = yield self.store.get_event(guest_access_id)
+
+ defer.returnValue(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
if membership:
yield self.store.forget(user_id, room_id)
+
+ @defer.inlineCallbacks
+ def _is_host_in_room(self, current_state_ids):
+ # Have we just created the room, and is this about to be the very
+ # first member event?
+ create_event_id = current_state_ids.get(("m.room.create", ""))
+ if len(current_state_ids) == 1 and create_event_id:
+ defer.returnValue(self.hs.is_mine_id(create_event_id))
+
+ for (etype, state_key), event_id in current_state_ids.items():
+ if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
+ continue
+
+ event = yield self.store.get_event(event_id, allow_none=True)
+ if not event:
+ continue
+
+ if event.membership == Membership.JOIN:
+ defer.returnValue(True)
+
+ defer.returnValue(False)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c8dfd02e7b..5cd009a1c8 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -355,11 +355,11 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- state = yield self.store.get_state_for_event(event.event_id)
+ state_ids = yield self.store.get_state_ids_for_event(event.event_id)
if event.is_state():
- state = state.copy()
- state[(event.type, event.state_key)] = event
- defer.returnValue(state)
+ state_ids = state_ids.copy()
+ state_ids[(event.type, event.state_key)] = event.event_id
+ defer.returnValue(state_ids)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
@@ -412,57 +412,61 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"):
if full_state:
if batch:
- current_state = yield self.store.get_state_for_event(
+ current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id
)
- state = yield self.store.get_state_for_event(
+ state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id
)
else:
- current_state = yield self.get_state_at(
+ current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token
)
- state = current_state
+ state_ids = current_state_ids
timeline_state = {
- (event.type, event.state_key): event
+ (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
- state = _calculate_state(
+ state_ids = _calculate_state(
timeline_contains=timeline_state,
- timeline_start=state,
+ timeline_start=state_ids,
previous={},
- current=current_state,
+ current=current_state_ids,
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
- current_state = yield self.store.get_state_for_event(
+ current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id
)
- state_at_timeline_start = yield self.store.get_state_for_event(
+ state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id
)
timeline_state = {
- (event.type, event.state_key): event
+ (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
- state = _calculate_state(
+ state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
- current=current_state,
+ current=current_state_ids,
)
else:
- state = {}
+ state_ids = {}
+
+ state = {}
+ if state_ids:
+ state = yield self.store.get_events(state_ids.values())
defer.returnValue({
(e.type, e.state_key): e
@@ -766,8 +770,13 @@ class SyncHandler(object):
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join:
- old_state = yield self.get_state_at(room_id, since_token)
- old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+ old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
+ old_mem_ev = None
+ if old_mem_ev_id:
+ old_mem_ev = yield self.store.get_event(
+ old_mem_ev_id, allow_none=True
+ )
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
@@ -1059,27 +1068,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
Returns:
dict
"""
- event_id_to_state = {
- e.event_id: e
- for e in itertools.chain(
- timeline_contains.values(),
- previous.values(),
- timeline_start.values(),
- current.values(),
+ event_id_to_key = {
+ e: key
+ for key, e in itertools.chain(
+ timeline_contains.items(),
+ previous.items(),
+ timeline_start.items(),
+ current.items(),
)
}
- c_ids = set(e.event_id for e in current.values())
- tc_ids = set(e.event_id for e in timeline_contains.values())
- p_ids = set(e.event_id for e in previous.values())
- ts_ids = set(e.event_id for e in timeline_start.values())
+ c_ids = set(e for e in current.values())
+ tc_ids = set(e for e in timeline_contains.values())
+ p_ids = set(e for e in previous.values())
+ ts_ids = set(e for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
- evs = (event_id_to_state[e] for e in state_ids)
return {
- (e.type, e.state_key): e
- for e in evs
+ event_id_to_key[e]: e for e in state_ids
}
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index ed2ccc4dfb..3f75d3f921 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,12 +40,12 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store, context.state_group, context.current_state
+ event, self.hs, self.store, context
)
with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
- event, context.current_state
+ event, context
)
context.push_actions = [
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 004eded61f..8d49beaec5 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,8 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-from synapse.api.constants import EventTypes, Membership
-from synapse.visibility import filter_events_for_clients
+from synapse.api.constants import EventTypes
+from synapse.visibility import filter_events_for_clients_context
logger = logging.getLogger(__name__)
@@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, state_group, current_state):
+def evaluator_for_event(event, hs, store, context):
rules_by_user = yield store.bulk_get_push_rules_for_room(
- event.room_id, state_group, current_state
+ event.room_id, context
)
# if this event is an invite event, we may need to run rules for the user
@@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
self.store = store
@defer.inlineCallbacks
- def action_for_event_by_user(self, event, current_state):
+ def action_for_event_by_user(self, event, context):
actions_by_user = {}
# None of these users can be peeking since this list of users comes
@@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
(u, False) for u in self.rules_by_user.keys()
]
- filtered_by_user = yield filter_events_for_clients(
- self.store, user_tuples, [event], {event.event_id: current_state}
+ filtered_by_user = yield filter_events_for_clients_context(
+ self.store, user_tuples, [event], {event.event_id: context}
)
- room_members = set(
- e.state_key for e in current_state.values()
- if e.type == EventTypes.Member and e.membership == Membership.JOIN
+ room_members = yield self.store.get_joined_users_from_context(
+ event.room_id, context,
)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
condition_cache = {}
- display_names = {}
- for ev in current_state.values():
- nm = ev.content.get("displayname", None)
- if nm and ev.type == EventTypes.Member:
- display_names[ev.state_key] = nm
-
for uid, rules in self.rules_by_user.items():
- display_name = display_names.get(uid, None)
+ display_name = None
+ member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
+ if member_ev_id:
+ member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
+ if member_ev:
+ display_name = member_ev.content.get("displayname", None)
filtered = filtered_by_user[uid]
if len(filtered) == 0:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index feedb075e2..c0f8176e3d 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -245,7 +245,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event(
- self.state_handler, event, self.user_id
+ self.store, self.state_handler, event, self.user_id
)
d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 1028731bc9..845ddd43da 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -22,7 +22,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from synapse.util.async import concurrently_execute
-from synapse.util.presentable_names import (
+from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events
)
from synapse.types import UserID
@@ -139,7 +139,7 @@ class Mailer(object):
@defer.inlineCallbacks
def _fetch_room_state(room_id):
- room_state = yield self.state_handler.get_current_state(room_id)
+ room_state = yield self.state_handler.get_current_state_ids(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
@@ -160,7 +160,8 @@ class Mailer(object):
rooms.append(roomvars)
reason['room_name'] = calculate_room_name(
- state_by_room[reason['room_id']], user_id, fallback_to_members=True
+ self.store, state_by_room[reason['room_id']], user_id,
+ fallback_to_members=True
)
summary_text = self.make_summary_text(
@@ -203,12 +204,15 @@ class Mailer(object):
)
@defer.inlineCallbacks
- def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
- my_member_event = room_state[("m.room.member", user_id)]
+ def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
+ my_member_event_id = room_state_ids[("m.room.member", user_id)]
+ my_member_event = yield self.store.get_event(my_member_event_id)
is_invite = my_member_event.content["membership"] == "invite"
+ room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
+
room_vars = {
- "title": calculate_room_name(room_state, user_id),
+ "title": room_name,
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
"invite": is_invite,
@@ -218,7 +222,7 @@ class Mailer(object):
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
- n, user_id, notif_events[n['event_id']], room_state
+ n, user_id, notif_events[n['event_id']], room_state_ids
)
# merge overlapping notifs together.
@@ -243,7 +247,7 @@ class Mailer(object):
defer.returnValue(room_vars)
@defer.inlineCallbacks
- def get_notif_vars(self, notif, user_id, notif_event, room_state):
+ def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = yield self.store.get_events_around(
notif['room_id'], notif['event_id'],
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
@@ -261,17 +265,19 @@ class Mailer(object):
the_events.append(notif_event)
for event in the_events:
- messagevars = self.get_message_vars(notif, event, room_state)
+ messagevars = yield self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None:
ret['messages'].append(messagevars)
defer.returnValue(ret)
- def get_message_vars(self, notif, event, room_state):
+ @defer.inlineCallbacks
+ def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message:
- return None
+ return
- sender_state_event = room_state[("m.room.member", event.sender)]
+ sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
+ sender_state_event = yield self.store.get_event(sender_state_event_id)
sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content.get("avatar_url")
@@ -299,7 +305,7 @@ class Mailer(object):
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
- return ret
+ defer.returnValue(ret)
def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format")
diff --git a/synapse/util/presentable_names.py b/synapse/push/presentable_names.py
index f68676e9e7..f90b789c05 100644
--- a/synapse/util/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
import re
import logging
@@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room"
-def calculate_room_name(room_state, user_id, fallback_to_members=True,
+@defer.inlineCallbacks
+def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
fallback_to_single_member=True):
"""
Works out a user-facing name for the given room as per Matrix
@@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
(string or None) A human readable name for the room.
"""
# does it have a name?
- if ("m.room.name", "") in room_state:
- m_room_name = room_state[("m.room.name", "")]
- if m_room_name.content and m_room_name.content["name"]:
- return m_room_name.content["name"]
+ if ("m.room.name", "") in room_state_ids:
+ m_room_name = yield store.get_event(
+ room_state_ids[("m.room.name", "")], allow_none=True
+ )
+ if m_room_name and m_room_name.content and m_room_name.content["name"]:
+ defer.returnValue(m_room_name.content["name"])
# does it have a canonical alias?
- if ("m.room.canonical_alias", "") in room_state:
- canon_alias = room_state[("m.room.canonical_alias", "")]
+ if ("m.room.canonical_alias", "") in room_state_ids:
+ canon_alias = yield store.get_event(
+ room_state_ids[("m.room.canonical_alias", "")], allow_none=True
+ )
if (
- canon_alias.content and canon_alias.content["alias"] and
+ canon_alias and canon_alias.content and canon_alias.content["alias"] and
_looks_like_an_alias(canon_alias.content["alias"])
):
- return canon_alias.content["alias"]
+ defer.returnValue(canon_alias.content["alias"])
# at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure
- room_state_bytype = _state_as_two_level_dict(room_state)
+ room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
# right then, any aliases at all?
- if "m.room.aliases" in room_state_bytype:
- m_room_aliases = room_state_bytype["m.room.aliases"]
- if len(m_room_aliases.values()) > 0:
- first_alias_event = m_room_aliases.values()[0]
- if first_alias_event.content and first_alias_event.content["aliases"]:
- the_aliases = first_alias_event.content["aliases"]
+ if "m.room.aliases" in room_state_bytype_ids:
+ m_room_aliases = room_state_bytype_ids["m.room.aliases"]
+ for alias_id in m_room_aliases.values():
+ alias_event = yield store.get_event(
+ alias_id, allow_none=True
+ )
+ if alias_event and alias_event.content and alias_event.get("aliases"):
+ the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
- return the_aliases[0]
+ defer.returnValue(the_aliases[0])
if not fallback_to_members:
- return None
+ defer.returnValue(None)
my_member_event = None
- if ("m.room.member", user_id) in room_state:
- my_member_event = room_state[("m.room.member", user_id)]
+ if ("m.room.member", user_id) in room_state_ids:
+ my_member_event = yield store.get_event(
+ room_state_ids[("m.room.member", user_id)], allow_none=True
+ )
if (
my_member_event is not None and
my_member_event.content['membership'] == "invite"
):
- if ("m.room.member", my_member_event.sender) in room_state:
- inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
- if fallback_to_single_member:
- return "Invite from %s" % (name_from_member_event(inviter_member_event),)
- else:
- return None
+ if ("m.room.member", my_member_event.sender) in room_state_ids:
+ inviter_member_event = yield store.get_event(
+ room_state_ids[("m.room.member", my_member_event.sender)],
+ allow_none=True,
+ )
+ if inviter_member_event:
+ if fallback_to_single_member:
+ defer.returnValue(
+ "Invite from %s" % (
+ name_from_member_event(inviter_member_event),
+ )
+ )
+ else:
+ return
else:
- return "Room Invite"
+ defer.returnValue("Room Invite")
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
- if "m.room.member" in room_state_bytype:
+ if "m.room.member" in room_state_bytype_ids:
+ member_events = yield store.get_events(
+ room_state_bytype_ids["m.room.member"].values()
+ )
all_members = [
- ev for ev in room_state_bytype["m.room.member"].values()
+ ev for ev in member_events.values()
if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
]
# Sort the member events oldest-first so the we name people in the
@@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
# self-chat, peeked room with 1 participant,
# or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id:
- if "m.room.third_party_invite" in room_state_bytype:
+ if "m.room.third_party_invite" in room_state_bytype_ids:
third_party_invites = (
- room_state_bytype["m.room.third_party_invite"].values()
+ room_state_bytype_ids["m.room.third_party_invite"].values()
)
if len(third_party_invites) > 0:
@@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
# return "Inviting %s" % (
# descriptor_from_member_events(third_party_invites)
# )
- return "Inviting email address"
+ defer.returnValue("Inviting email address")
else:
- return ALL_ALONE
+ defer.returnValue(ALL_ALONE)
else:
- return name_from_member_event(all_members[0])
+ defer.returnValue(name_from_member_event(all_members[0]))
else:
- return ALL_ALONE
+ defer.returnValue(ALL_ALONE)
elif len(other_members) == 1 and not fallback_to_single_member:
- return None
+ return
else:
- return descriptor_from_member_events(other_members)
+ defer.returnValue(descriptor_from_member_events(other_members))
def descriptor_from_member_events(member_events):
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index becb8ef1ae..b47bf1f92b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -14,7 +14,7 @@
# limitations under the License.
from twisted.internet import defer
-from synapse.util.presentable_names import (
+from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event
)
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@@ -49,21 +49,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
-def get_context_for_event(state_handler, ev, user_id):
+def get_context_for_event(store, state_handler, ev, user_id):
ctx = {}
- room_state = yield state_handler.get_current_state(ev.room_id)
+ room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
- name = calculate_room_name(
- room_state, user_id, fallback_to_single_member=False
+ name = yield calculate_room_name(
+ store, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
ctx['name'] = name
- sender_state_event = room_state[("m.room.member", ev.sender)]
+ sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
+ sender_state_event = yield store.get_event(sender_state_event_id)
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index f4f31f2d27..65e982a0ce 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -120,10 +120,15 @@ class SlavedEventStore(BaseSlavedStore):
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
+ get_state_groups_ids = DataStore.get_state_groups_ids.__func__
+ get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
+ get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
+ is_host_joined = DataStore.is_host_joined.__func__
+ _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
diff --git a/synapse/state.py b/synapse/state.py
index ef1bc470be..78461215ca 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -93,8 +93,30 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- res = yield self.resolve_state_groups(room_id, latest_event_ids)
- state = res[1]
+ _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+
+ if event_type:
+ event_id = state.get((event_type, state_key))
+ event = None
+ if event_id:
+ event = yield self.store.get_event(event_id, allow_none=True)
+ defer.returnValue(event)
+ return
+
+ state_map = yield self.store.get_events(state.values(), get_prev_content=False)
+ state = {
+ key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+ }
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
+ def get_current_state_ids(self, room_id, event_type=None, state_key="",
+ latest_event_ids=None):
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+
+ _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
if event_type:
defer.returnValue(state.get((event_type, state_key)))
@@ -123,27 +145,27 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
- context.current_state = {
- (s.type, s.state_key): s for s in old_state
+ context.current_state_ids = {
+ (s.type, s.state_key): s.event_id for s in old_state
}
else:
- context.current_state = {}
+ context.current_state_ids = {}
context.prev_state_events = []
context.state_group = None
defer.returnValue(context)
if old_state:
- context.current_state = {
- (s.type, s.state_key): s for s in old_state
+ context.current_state_ids = {
+ (s.type, s.state_key): s.event_id for s in old_state
}
context.state_group = None
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state:
- replaces = context.current_state[key]
- if replaces.event_id != event.event_id: # Paranoia check
- event.unsigned["replaces_state"] = replaces.event_id
+ if key in context.current_state_ids:
+ replaces = context.current_state_ids[key]
+ if replaces != event.event_id: # Paranoia check
+ event.unsigned["replaces_state"] = replaces
context.prev_state_events = []
defer.returnValue(context)
@@ -159,18 +181,18 @@ class StateHandler(object):
event.room_id, [e for e, _ in event.prev_events],
)
- group, curr_state, prev_state = ret
+ group, curr_state = ret
- context.current_state = curr_state
+ context.current_state_ids = curr_state
context.state_group = group if not event.is_state() else None
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state:
- replaces = context.current_state[key]
- event.unsigned["replaces_state"] = replaces.event_id
+ if key in context.current_state_ids:
+ replaces = context.current_state_ids[key]
+ event.unsigned["replaces_state"] = replaces
- context.prev_state_events = prev_state
+ context.prev_state_events = []
defer.returnValue(context)
@defer.inlineCallbacks
@@ -187,72 +209,83 @@ class StateHandler(object):
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
- state_groups = yield self.store.get_state_groups(
+ state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids
)
logger.debug(
"resolve_state_groups state_groups %s",
- state_groups.keys()
+ state_groups_ids.keys()
)
- group_names = frozenset(state_groups.keys())
+ group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1:
- name, state_list = state_groups.items().pop()
- state = {
- (e.type, e.state_key): e
- for e in state_list
- }
- prev_state = state.get((event_type, state_key), None)
- if prev_state:
- prev_state = prev_state.event_id
- prev_states = [prev_state]
- else:
- prev_states = []
+ name, state_list = state_groups_ids.items().pop()
- defer.returnValue((name, state, prev_states))
+ defer.returnValue((name, state_list,))
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
cache.ts = self.clock.time_msec()
- event_dict = yield self.store.get_events(cache.state.values())
- state = {(e.type, e.state_key): e for e in event_dict.values()}
-
- prev_state = state.get((event_type, state_key), None)
- if prev_state:
- prev_state = prev_state.event_id
- prev_states = [prev_state]
- else:
- prev_states = []
defer.returnValue(
- (cache.state_group, state, prev_states)
+ (cache.state_group, cache.state,)
)
- logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
-
- new_state, prev_states = self._resolve_events(
- state_groups.values(), event_type, state_key
+ logger.info(
+ "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
+ state = {}
+ for st in state_groups_ids.values():
+ for key, e_id in st.items():
+ state.setdefault(key, set()).add(e_id)
+
+ conflicted_state = {
+ k: list(v)
+ for k, v in state.items()
+ if len(v) > 1
+ }
+
+ if conflicted_state:
+ logger.info("Resolving conflicted state for %r", room_id)
+ state_map = yield self.store.get_events(
+ [e_id for st in state_groups_ids.values() for e_id in st.values()],
+ get_prev_content=False
+ )
+ state_sets = [
+ [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
+ for st in state_groups_ids.values()
+ ]
+ new_state, _ = self._resolve_events(
+ state_sets, event_type, state_key
+ )
+ new_state = {
+ key: e.event_id for key, e in new_state.items()
+ }
+ else:
+ new_state = {
+ key: e_ids.pop() for key, e_ids in state.items()
+ }
+
state_group = None
- new_state_event_ids = frozenset(e.event_id for e in new_state.values())
- for sg, events in state_groups.items():
- if new_state_event_ids == frozenset(e.event_id for e in events):
+ new_state_event_ids = frozenset(new_state.values())
+ for sg, events in state_groups_ids.items():
+ if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
if self._state_cache is not None:
cache = _StateCacheEntry(
- state={key: event.event_id for key, event in new_state.items()},
+ state=new_state,
state_group=state_group,
ts=self.clock.time_msec()
)
self._state_cache[group_names] = cache
- defer.returnValue((state_group, new_state, prev_states))
+ defer.returnValue((state_group, new_state,))
def resolve_events(self, state_sets, event):
logger.info(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 78334a98cf..7e6ec411cd 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
- def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+ def bulk_get_push_rules_for_room(self, room_id, context):
+ state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+ return self._bulk_get_push_rules_for_room(
+ room_id, state_group, context.current_state_ids
+ )
@cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
+ def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
@@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore):
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
+ local_user_member_ids = [
+ e_id for (etype, state_key), e_id in current_state_ids.iteritems()
+ if etype == EventTypes.Member and self.hs.is_mine_id(state_key)
+ ]
+
+ local_member_events = yield self._get_events(local_user_member_ids)
+
local_users_in_room = set(
- e.state_key for e in current_state.values()
- if e.type == EventTypes.Member and e.membership == Membership.JOIN
- and self.hs.is_mine_id(e.state_key)
+ member_event.state_key for member_event in local_member_events
+ if member_event.membership == Membership.JOIN
)
# users in the room who have pushers need to get push rules run because
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index a422ddf633..5f15200c20 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -20,7 +20,7 @@ from collections import namedtuple
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id
import logging
@@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id):
- """Returns whether user_id has elected to discard history for room_id at event_id.
+ """Returns whether user_id has elected to discard history for room_id at
+ event_id.
event_id must be a membership event."""
def f(txn):
@@ -358,3 +359,80 @@ class RoomMemberStore(SQLBaseStore):
},
desc="who_forgot"
)
+
+ def get_joined_users_from_context(self, room_id, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ room_id, state_group, context.current_state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+ cache_context):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ member_event_ids = [
+ e_id
+ for key, e_id in current_state_ids.iteritems()
+ if key[0] == EventTypes.Member
+ ]
+
+ rows = yield self._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=['user_id'],
+ keyvalues={
+ "membership": Membership.JOIN,
+ },
+ batch_size=1000,
+ desc="_get_joined_users_from_context",
+ )
+
+ defer.returnValue(set(row["user_id"] for row in rows))
+
+ def is_host_joined(self, room_id, host, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._is_host_joined(
+ room_id, host, state_group, state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=3)
+ def _is_host_joined(self, room_id, host, state_group, current_state_ids):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ for (etype, state_key), event_id in current_state_ids.items():
+ if etype == EventTypes.Member:
+ try:
+ if get_domain_from_id(state_key) != host:
+ continue
+ except:
+ logger.warn("state_key not user_id: %s", state_key)
+ continue
+
+ event = yield self.get_event(event_id, allow_none=True)
+ if event and event.content["membership"] == Membership.JOIN:
+ defer.returnValue(True)
+
+ defer.returnValue(False)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0e8fa93e1f..b1d461fef5 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
"""
@defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
- """ Get the state groups for the given list of event_ids
-
- The return value is a dict mapping group names to lists of events.
- """
+ def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
defer.returnValue({})
@@ -59,9 +55,32 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
+ defer.returnValue(group_to_state)
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, room_id, event_ids):
+ """ Get the state groups for the given list of event_ids
+
+ The return value is a dict mapping group names to lists of events.
+ """
+ if not event_ids:
+ defer.returnValue({})
+
+ group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = yield self.get_events(
+ [
+ ev_id for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
+ ],
+ get_prev_content=False
+ )
+
defer.returnValue({
- group: state_map.values()
- for group, state_map in group_to_state.items()
+ group: [
+ state_event_map[v] for v in event_id_map.values() if v in state_event_map
+ ]
+ for group, event_id_map in group_to_ids.items()
})
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
@@ -70,17 +89,17 @@ class StateStore(SQLBaseStore):
if event.internal_metadata.is_outlier():
continue
- if context.current_state is None:
+ if context.current_state_ids is None:
continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group
continue
- state_events = dict(context.current_state)
+ state_event_ids = dict(context.current_state_ids)
if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ state_event_ids[(event.type, event.state_key)] = event.event_id
state_group = context.new_state_group_id
@@ -100,12 +119,12 @@ class StateStore(SQLBaseStore):
values=[
{
"state_group": state_group,
- "room_id": state.room_id,
- "type": state.type,
- "state_key": state.state_key,
- "event_id": state.event_id,
+ "room_id": event.room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
}
- for state in state_events.values()
+ for key, state_id in state_event_ids.items()
],
)
state_groups[event.event_id] = state_group
@@ -248,6 +267,31 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
+ state_event_map = yield self.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ get_prev_content=False
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in group_to_state[group].items()
+ if v in state_event_map
+ }
+ for event_id, group in event_to_groups.items()
+ }
+
+ defer.returnValue({event: event_to_state[event] for event in event_ids})
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_events(self, event_ids, types):
+ event_to_groups = yield self._get_state_group_for_events(
+ event_ids,
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = yield self._get_state_for_groups(groups, types)
+
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
@@ -272,6 +316,23 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
+ @defer.inlineCallbacks
+ def get_state_ids_for_event(self, event_id, types=None):
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ types(list[(str, str)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. May be None, which
+ matches any key
+
+ Returns:
+ A deferred dict from (type, state_key) -> state_event
+ """
+ state_map = yield self.get_state_ids_for_events([event_id], types)
+ defer.returnValue(state_map[event_id])
+
@cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
@@ -428,20 +489,13 @@ class StateStore(SQLBaseStore):
full=(types is None),
)
- state_events = yield self._get_events(
- [ev_id for sd in results.values() for ev_id in sd.values()],
- get_prev_content=False
- )
-
- state_events = {e.event_id: e for e in state_events}
-
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = {
- key: state_events[event_id]
+ key: event_id
for key, event_id in state_dict.items()
- if event_id and event_id in state_events
+ if event_id
}
defer.returnValue(results)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index cc12c0a23d..199b16d827 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -181,6 +181,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
@defer.inlineCallbacks
+def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
+ user_ids = set(u[0] for u in user_tuples)
+ event_id_to_state = {}
+ for event_id, context in event_id_to_context.items():
+ state = yield store.get_events([
+ e_id
+ for key, e_id in context.current_state_ids.iteritems()
+ if key == (EventTypes.RoomHistoryVisibility, "")
+ or (key[0] == EventTypes.Member and key[1] in user_ids)
+ ])
+ event_id_to_state[event_id] = state
+
+ res = yield filter_events_for_clients(
+ store, user_tuples, events, event_id_to_state
+ )
+ defer.returnValue(res)
+
+
+@defer.inlineCallbacks
def filter_events_for_client(store, user_id, events, is_peeking=False):
"""
Check which events a user is allowed to see
|