From 5dc2a702cf2415a55bab8061053c0b47dc21dc93 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Aug 2016 16:54:40 +0100 Subject: Make _state_groups_id_gen a normal IdGenerator --- synapse/storage/__init__.py | 2 +- synapse/storage/events.py | 83 +++++++++++++++++++++------------------------ synapse/storage/state.py | 3 -- 3 files changed, 40 insertions(+), 48 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4f4c723c5b..6c32773f25 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -115,7 +115,7 @@ class DataStore(RoomMemberStore, RoomStore, ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5cbe8c5978..bc1bc97e19 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -271,39 +271,35 @@ class EventsStore(SQLBaseStore): len(events_and_contexts) ) - state_group_id_manager = self._state_groups_id_gen.get_next_mult( - len(events_and_contexts) - ) with stream_ordering_manager as stream_orderings: - with state_group_id_manager as state_group_ids: - for (event, context), stream, state_group_id in zip( - events_and_contexts, stream_orderings, state_group_ids - ): - event.internal_metadata.stream_ordering = stream - # Assign a state group_id in case a new id is needed for - # this context. In theory we only need to assign this - # for contexts that have current_state and aren't outliers - # but that make the code more complicated. Assigning an ID - # per event only causes the state_group_ids to grow as fast - # as the stream_ordering so in practise shouldn't be a problem. - context.new_state_group_id = state_group_id - - chunks = [ - events_and_contexts[x:x + 100] - for x in xrange(0, len(events_and_contexts), 100) - ] + for (event, context), stream, in zip( + events_and_contexts, stream_orderings + ): + event.internal_metadata.stream_ordering = stream + # Assign a state group_id in case a new id is needed for + # this context. In theory we only need to assign this + # for contexts that have current_state and aren't outliers + # but that make the code more complicated. Assigning an ID + # per event only causes the state_group_ids to grow as fast + # as the stream_ordering so in practise shouldn't be a problem. + context.new_state_group_id = self._state_groups_id_gen.get_next() + + chunks = [ + events_and_contexts[x:x + 100] + for x in xrange(0, len(events_and_contexts), 100) + ] - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc_by(len(chunk)) + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + delete_existing=delete_existing, + ) + persist_event_counter.inc_by(len(chunk)) @_retry_on_integrity_error @defer.inlineCallbacks @@ -312,19 +308,18 @@ class EventsStore(SQLBaseStore): delete_existing=False): try: with self._stream_id_gen.get_next() as stream_ordering: - with self._state_groups_id_gen.get_next() as state_group_id: - event.internal_metadata.stream_ordering = stream_ordering - context.new_state_group_id = state_group_id - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() + event.internal_metadata.stream_ordering = stream_ordering + context.new_state_group_id = self._state_groups_id_gen.get_next() + yield self.runInteraction( + "persist_event", + self._persist_event_txn, + event=event, + context=context, + current_state=current_state, + backfilled=backfilled, + delete_existing=delete_existing, + ) + persist_event_counter.inc() except _RollbackButIsFineException: pass diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b1d461fef5..0442353287 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -526,6 +526,3 @@ class StateStore(SQLBaseStore): return self.runInteraction( "get_all_new_state_groups", get_all_new_state_groups_txn ) - - def get_state_stream_token(self): - return self._state_groups_id_gen.get_current_token() -- cgit 1.4.1 From 1bb8ec296de4f6bf37b0b31f28bc0d7285f96ba0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 31 Aug 2016 10:09:46 +0100 Subject: Generate state group ids in state layer --- synapse/state.py | 9 ++++++--- synapse/storage/events.py | 10 +--------- synapse/storage/state.py | 24 +++++++++++++++++------- 3 files changed, 24 insertions(+), 19 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/state.py b/synapse/state.py index daec983dc9..147416fd81 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -160,14 +160,14 @@ class StateHandler(object): else: context.current_state_ids = {} context.prev_state_events = [] - context.state_group = None + context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: context.current_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = None + context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) @@ -193,7 +193,10 @@ class StateHandler(object): group, curr_state = ret context.current_state_ids = curr_state - context.state_group = group if not event.is_state() else None + if event.is_state() or group is None: + context.state_group = self.store.get_next_state_group() + else: + context.state_group = group if event.is_state(): key = (event.type, event.state_key) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bc1bc97e19..1a7d4c5199 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -276,13 +276,6 @@ class EventsStore(SQLBaseStore): events_and_contexts, stream_orderings ): event.internal_metadata.stream_ordering = stream - # Assign a state group_id in case a new id is needed for - # this context. In theory we only need to assign this - # for contexts that have current_state and aren't outliers - # but that make the code more complicated. Assigning an ID - # per event only causes the state_group_ids to grow as fast - # as the stream_ordering so in practise shouldn't be a problem. - context.new_state_group_id = self._state_groups_id_gen.get_next() chunks = [ events_and_contexts[x:x + 100] @@ -309,7 +302,6 @@ class EventsStore(SQLBaseStore): try: with self._stream_id_gen.get_next() as stream_ordering: event.internal_metadata.stream_ordering = stream_ordering - context.new_state_group_id = self._state_groups_id_gen.get_next() yield self.runInteraction( "persist_event", self._persist_event_txn, @@ -523,7 +515,7 @@ class EventsStore(SQLBaseStore): # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group or context.new_state_group_id + state_group_id = context.state_group self._simple_insert_txn( txn, table="ex_outlier_stream", diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0442353287..56bfdc0b55 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,6 +83,14 @@ class StateStore(SQLBaseStore): for group, event_id_map in group_to_ids.items() }) + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups_state WHERE state_group = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: @@ -92,8 +100,10 @@ class StateStore(SQLBaseStore): if context.current_state_ids is None: continue - if context.state_group is not None: - state_groups[event.event_id] = context.state_group + state_groups[event.event_id] = context.state_group + + if self._have_persisted_state_group_txn(txn, context.state_group): + logger.info("Already persisted state_group: %r", context.state_group) continue state_event_ids = dict(context.current_state_ids) @@ -101,13 +111,11 @@ class StateStore(SQLBaseStore): if event.is_state(): state_event_ids[(event.type, event.state_key)] = event.event_id - state_group = context.new_state_group_id - self._simple_insert_txn( txn, table="state_groups", values={ - "id": state_group, + "id": context.state_group, "room_id": event.room_id, "event_id": event.event_id, }, @@ -118,7 +126,7 @@ class StateStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": state_group, + "state_group": context.state_group, "room_id": event.room_id, "type": key[0], "state_key": key[1], @@ -127,7 +135,6 @@ class StateStore(SQLBaseStore): for key, state_id in state_event_ids.items() ], ) - state_groups[event.event_id] = state_group self._simple_insert_many_txn( txn, @@ -526,3 +533,6 @@ class StateStore(SQLBaseStore): return self.runInteraction( "get_all_new_state_groups", get_all_new_state_groups_txn ) + + def get_next_state_group(self): + return self._state_groups_id_gen.get_next() -- cgit 1.4.1 From c10cb581c6ce54e7dfa1f8a0f6449ee7f6d049d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 31 Aug 2016 13:55:02 +0100 Subject: Correctly handle the difference between prev and current state --- synapse/api/auth.py | 4 +-- synapse/events/snapshot.py | 5 ++-- synapse/handlers/federation.py | 31 +++++++++++++++------ synapse/handlers/message.py | 10 +++---- synapse/handlers/room_member.py | 6 ++-- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/state.py | 31 +++++++++++++++------ synapse/storage/roommember.py | 27 +++++++++++++++--- synapse/storage/state.py | 3 -- tests/replication/slave/storage/test_events.py | 4 ++- tests/replication/test_resource.py | 10 ++----- tests/test_state.py | 38 ++++++++++---------------- 12 files changed, 102 insertions(+), 69 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f26e585623..fcf0b0d25f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -66,7 +66,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): auth_events_ids = yield self.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -852,7 +852,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) + auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index c75afd02d8..e895b1c450 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,8 +15,9 @@ class EventContext(object): - def __init__(self, current_state_ids=None): - self.current_state_ids = current_state_ids + def __init__(self): + self.current_state_ids = None + self.prev_state_ids = None self.state_group = None self.rejected = False self.push_actions = [] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a7ea8fb98f..8e61d74b13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -222,7 +222,7 @@ class FederationHandler(BaseHandler): # joined the room. Don't bother if the user is just # changing their profile info. newly_joined = True - prev_state_id = context.current_state_ids.get( + prev_state_id = context.prev_state_ids.get( (event.type, event.state_key) ) if prev_state_id: @@ -835,12 +835,12 @@ class FederationHandler(BaseHandler): self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = context.current_state_ids.values() + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) - state = yield self.store.get_events(context.current_state_ids.values()) + state = yield self.store.get_events(context.prev_state_ids.values()) defer.returnValue({ "state": state.values(), @@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler): if not auth_events: auth_events_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler): context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() }) - context.state_group = None + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. auth_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler): context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() }) - context.state_group = None + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) @@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"] ) original_invite = None - original_invite_id = context.current_state_ids.get(key) + original_invite_id = context.prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( original_invite_id, allow_none=True @@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event_id = context.current_state_ids.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2f4387f60..3577db0595 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -272,7 +272,7 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event_id = context.current_state_ids.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -808,8 +808,8 @@ class MessageHandler(BaseHandler): event = builder.build() logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state_ids, + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, ) defer.returnValue( @@ -904,7 +904,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Redaction: auth_events_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -924,7 +924,7 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.current_state_ids: + if event.type == EventTypes.Create and context.prev_state_ids: raise AuthError( 403, "Changing the room create event is forbidden", diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index dd4b90ee24..3ba5335af7 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.current_state_ids.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) @@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler): if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(context.current_state_ids) + guest_can_join = yield self._can_guest_join(context.prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.current_state_ids.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, event.state_key), None ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 51cb21ee9d..6ff9a06de1 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,7 +87,7 @@ class BulkPushRuleEvaluator: ) room_members = yield self.store.get_joined_users_from_context( - event.room_id, context.state_group, context.current_state_ids + event, context ) evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) diff --git a/synapse/state.py b/synapse/state.py index 147416fd81..a0f807e3b9 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -128,7 +128,7 @@ class StateHandler(object): def get_current_user_in_room(self, room_id): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_context( + joined_users = yield self.store.get_joined_users_from_state( room_id, group, state_ids ) defer.returnValue(joined_users) @@ -154,27 +154,38 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } + if event.is_state(): + context.current_state_events = dict(context.prev_state_ids) + key = (event.type, event.state_key) + context.current_state_events[key] = event.event_id + else: + context.current_state_events = context.prev_state_ids else: context.current_state_ids = {} + context.prev_state_ids = {} context.prev_state_events = [] context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] if replaces != event.event_id: # Paranoia check event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) @@ -192,7 +203,7 @@ class StateHandler(object): group, curr_state = ret - context.current_state_ids = curr_state + context.prev_state_ids = curr_state if event.is_state() or group is None: context.state_group = self.store.get_next_state_group() else: @@ -200,9 +211,13 @@ class StateHandler(object): if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index cab1660830..6ab10db328 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore): desc="who_forgot" ) - def get_joined_users_from_context(self, room_id, state_group, state_ids): + def get_joined_users_from_context(self, event, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_ids + event.room_id, state_group, context.current_state_ids, event=event, + ) + + def get_joined_users_from_state(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids, ) @cachedInlineCallbacks(num_args=2, cache_context=True) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context): + cache_context, event=None): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. @@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore): desc="_get_joined_users_from_context", ) - defer.returnValue(set(row["user_id"] for row in rows)) + users_in_room = set(row["user_id"] for row in rows) + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room.add(event.state_key) + + defer.returnValue(users_in_room) def is_host_joined(self, room_id, host, state_group, state_ids): if not state_group: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 56bfdc0b55..dce5a2f135 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -108,9 +108,6 @@ class StateStore(SQLBaseStore): state_event_ids = dict(context.current_state_ids) - if event.is_state(): - state_event_ids[(event.type, event.state_key)] = event.event_id - self._simple_insert_txn( txn, table="state_groups", diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 218cb24889..44e859b5d1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): else: state_ids = None - context = EventContext(current_state_ids=state_ids) + context = EventContext() + context.current_state_ids = state_ids + context.prev_state_ids = state_ids context.push_actions = push_actions ordering = None diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index e70ac6f14d..b69832cc1b 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events_and_state(self): - get = self.get(events="-1", state="-1", timeout="0") + def test_events(self): + get = self.get(events="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( synapse.types.create_requester(self.user), {} ) @@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body["events"]["field_names"], [ "position", "internal", "json", "state_group" ]) - self.assertEquals(body["state_groups"]["field_names"], [ - "position", "room_id", "event_id" - ]) - self.assertEquals(body["state_group_state"]["field_names"], [ - "position", "type", "state_key", "event_id" - ]) @defer.inlineCallbacks def test_presence(self): diff --git a/tests/test_state.py b/tests/test_state.py index de2d35145a..6454f994e3 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -86,17 +86,8 @@ class StateGroupStore(object): state_events = dict(context.current_state_ids) - if event.is_state(): - state_events[(event.type, event.state_key)] = event.event_id - - state_group = context.state_group - if not state_group: - state_group = self._next_group - self._next_group += 1 - - self._group_to_state[state_group] = state_events - - self._event_to_state_group[event.event_id] = state_group + self._group_to_state[context.state_group] = state_events + self._event_to_state_group[event.event_id] = context.state_group def get_events(self, event_ids, **kwargs): return { @@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase): "get_state_groups_ids", "add_event_hashes", "get_events", + "get_next_state_group", ] ) hs = Mock(spec_set=[ @@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + self.store.get_next_state_group.side_effect = Mock + self.state = StateHandler(hs) self.event_id = 0 @@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state_ids)) + self.assertEqual(2, len(context_store["D"].prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].current_state_ids.values()} + {e_id for e_id in context_store["D"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].current_state_ids.values()} + {e for e in context_store["E"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].current_state_ids.values()} + {e for e in context_store["D"].prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase): set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) - @defer.inlineCallbacks def test_trivial_annotate_message(self): event = create_event(type="test_message", name="event") @@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase): self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): @@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): @@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): -- cgit 1.4.1 From 0cfd6c316193895b6fd5c761253305cf8688cef1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 31 Aug 2016 16:25:57 +0100 Subject: Use state_groups table to test existence --- synapse/storage/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dce5a2f135..ec551b0b4f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -85,7 +85,7 @@ class StateStore(SQLBaseStore): def _have_persisted_state_group_txn(self, txn, state_group): txn.execute( - "SELECT count(*) FROM state_groups_state WHERE state_group = ?", + "SELECT count(*) FROM state_groups WHERE id = ?", (state_group,) ) row = txn.fetchone() -- cgit 1.4.1