diff options
author | Mark Haines <mark.haines@matrix.org> | 2016-08-26 14:35:31 +0100 |
---|---|---|
committer | Mark Haines <mark.haines@matrix.org> | 2016-08-26 14:35:31 +0100 |
commit | 4bbef62124f0fb249e314e94a1b9c15204c8daa9 (patch) | |
tree | cc2e7458da160915f7224c0487f185e2433eaadb /synapse/storage | |
parent | More 0_0 in tests (diff) | |
parent | Merge pull request #1048 from matrix-org/erikj/fix_mail_name (diff) | |
download | synapse-4bbef62124f0fb249e314e94a1b9c15204c8daa9.tar.xz |
Merge remote-tracking branch 'origin/develop' into markjh/direct_to_device
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/push_rule.py | 21 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 82 | ||||
-rw-r--r-- | synapse/storage/state.py | 102 |
3 files changed, 173 insertions, 32 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 78334a98cf..7e6ec411cd 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) - def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): + def bulk_get_push_rules_for_room(self, room_id, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) + return self._bulk_get_push_rules_for_room( + room_id, state_group, context.current_state_ids + ) @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, cache_context): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's @@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore): # their unread countss are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've # sent a read receipt into the room. + local_user_member_ids = [ + e_id for (etype, state_key), e_id in current_state_ids.iteritems() + if etype == EventTypes.Member and self.hs.is_mine_id(state_key) + ] + + local_member_events = yield self._get_events(local_user_member_ids) + local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and self.hs.is_mine_id(e.state_key) + member_event.state_key for member_event in local_member_events + if member_event.membership == Membership.JOIN ) # users in the room who have pushers need to get push rules run because diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index a422ddf633..5f15200c20 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,7 +20,7 @@ from collections import namedtuple from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.api.constants import Membership +from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging @@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3) def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at event_id. + """Returns whether user_id has elected to discard history for room_id at + event_id. event_id must be a membership event.""" def f(txn): @@ -358,3 +359,80 @@ class RoomMemberStore(SQLBaseStore): }, desc="who_forgot" ) + + def get_joined_users_from_context(self, room_id, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, context.current_state_ids + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=['user_id'], + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=1000, + desc="_get_joined_users_from_context", + ) + + defer.returnValue(set(row["user_id"] for row in rows)) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._is_host_joined( + room_id, host, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0e8fa93e1f..b1d461fef5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -44,11 +44,7 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - The return value is a dict mapping group names to lists of events. - """ + def get_state_groups_ids(self, room_id, event_ids): if not event_ids: defer.returnValue({}) @@ -59,9 +55,32 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups) + defer.returnValue(group_to_state) + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + The return value is a dict mapping group names to lists of events. + """ + if not event_ids: + defer.returnValue({}) + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False + ) + defer.returnValue({ - group: state_map.values() - for group, state_map in group_to_state.items() + group: [ + state_event_map[v] for v in event_id_map.values() if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() }) def _store_mult_state_groups_txn(self, txn, events_and_contexts): @@ -70,17 +89,17 @@ class StateStore(SQLBaseStore): if event.internal_metadata.is_outlier(): continue - if context.current_state is None: + if context.current_state_ids is None: continue if context.state_group is not None: state_groups[event.event_id] = context.state_group continue - state_events = dict(context.current_state) + state_event_ids = dict(context.current_state_ids) if event.is_state(): - state_events[(event.type, event.state_key)] = event + state_event_ids[(event.type, event.state_key)] = event.event_id state_group = context.new_state_group_id @@ -100,12 +119,12 @@ class StateStore(SQLBaseStore): values=[ { "state_group": state_group, - "room_id": state.room_id, - "type": state.type, - "state_key": state.state_key, - "event_id": state.event_id, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, } - for state in state_events.values() + for key, state_id in state_event_ids.items() ], ) state_groups[event.event_id] = state_group @@ -248,6 +267,31 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups, types) + state_event_map = yield self.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, types): + event_to_groups = yield self._get_state_group_for_events( + event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + event_to_state = { event_id: group_to_state[group] for event_id, group in event_to_groups.items() @@ -272,6 +316,23 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, types=None): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], types) + defer.returnValue(state_map[event_id]) + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( @@ -428,20 +489,13 @@ class StateStore(SQLBaseStore): full=(types is None), ) - state_events = yield self._get_events( - [ev_id for sd in results.values() for ev_id in sd.values()], - get_prev_content=False - ) - - state_events = {e.event_id: e for e in state_events} - # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. for group, state_dict in results.items(): results[group] = { - key: state_events[event_id] + key: event_id for key, event_id in state_dict.items() - if event_id and event_id in state_events + if event_id } defer.returnValue(results) |