From a3dc1e9cbe491aa981b8bbaeb2414b4ec8e5b9ca Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 17:32:22 +0100 Subject: Replace context.current_state with context.current_state_ids --- synapse/api/auth.py | 68 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 24 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0db26fcfd7..40c3e9db0d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -52,7 +52,7 @@ class Auth(object): self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 # Docs for these currently lives at - # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst + # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # In addition, we have type == delete_pusher which grants access only to # delete pushers. self._KNOWN_CAVEAT_PREFIXES = set([ @@ -63,6 +63,17 @@ class Auth(object): "user_id = ", ]) + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.current_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=False) + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. @@ -847,7 +858,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = self.compute_auth_events(builder, context.current_state) + auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -855,30 +866,32 @@ class Auth(object): builder.auth_events = auth_events_entries - def compute_auth_events(self, event, current_state): + @defer.inlineCallbacks + def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - return [] + defer.returnValue([]) auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = current_state.get(key) + power_level_event_id = current_state_ids.get(key) - if power_level_event: - auth_ids.append(power_level_event.event_id) + if power_level_event_id: + auth_ids.append(power_level_event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = current_state.get(key) + join_rule_event_id = current_state_ids.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = current_state.get(key) + member_event_id = current_state_ids.get(key) key = (EventTypes.Create, "", ) - create_event = current_state.get(key) - if create_event: - auth_ids.append(create_event.event_id) + create_event_id = current_state_ids.get(key) + if create_event_id: + auth_ids.append(create_event_id) - if join_rule_event: + if join_rule_event_id: + join_rule_event = yield self.store.get_event(join_rule_event_id) join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False else: @@ -887,15 +900,21 @@ class Auth(object): if event.type == EventTypes.Member: e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event: - auth_ids.append(join_rule_event.event_id) + if join_rule_event_id: + auth_ids.append(join_rule_event_id) if e_type == Membership.JOIN: - if member_event and not is_public: - auth_ids.append(member_event.event_id) + if member_event_id and not is_public: + auth_ids.append(member_event_id) else: - if member_event: - auth_ids.append(member_event.event_id) + if member_event_id: + auth_ids.append(member_event_id) + + if for_verification: + key = (EventTypes.Member, event.state_key, ) + existing_event_id = current_state_ids.get(key) + if existing_event_id: + auth_ids.append(existing_event_id) if e_type == Membership.INVITE: if "third_party_invite" in event.content: @@ -903,14 +922,15 @@ class Auth(object): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - third_party_invite = current_state.get(key) - if third_party_invite: - auth_ids.append(third_party_invite.event_id) - elif member_event: + third_party_invite_id = current_state_ids.get(key) + if third_party_invite_id: + auth_ids.append(third_party_invite_id) + elif member_event_id: + member_event = yield self.store.get_event(member_event_id) if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - return auth_ids + defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) -- cgit 1.5.1 From 142983b4eafc93fc42a889b579d10e2b78199c48 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 25 Aug 2016 18:06:05 +0100 Subject: APP_SERVICE_PREFIX is never used; don't bother --- synapse/api/urls.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/api') diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 0fd9b7f244..91a33a3402 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" -APP_SERVICE_PREFIX = "/_matrix/appservice/v1" -- cgit 1.5.1 From 0e1900d8193a612d6920a9eca0aec4813e17d355 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 18:15:51 +0100 Subject: Pull out full state less --- synapse/api/auth.py | 13 +++++++------ synapse/state.py | 12 ++++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 40c3e9db0d..23a928de16 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -278,18 +278,19 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): - curr_state = yield self.state.get_current_state(room_id) + curr_state_id = yield self.state.get_current_state_ids(room_id) - for event in curr_state.values(): - if event.type == EventTypes.Member: + for (etype, state_key), event_id in curr_state_id.items(): + if etype == EventTypes.Member: try: - if get_domain_from_id(event.state_key) != host: + if get_domain_from_id(state_key) != host: continue except: - logger.warn("state_key not user_id: %s", event.state_key) + logger.warn("state_key not user_id: %s", state_key) continue - if event.content["membership"] == Membership.JOIN: + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: defer.returnValue(True) defer.returnValue(False) diff --git a/synapse/state.py b/synapse/state.py index 2a01887a67..78461215ca 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -95,15 +95,19 @@ class StateHandler(object): _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + if event_type: + event_id = state.get((event_type, state_key)) + event = None + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + defer.returnValue(event) + return + state_map = yield self.store.get_events(state.values(), get_prev_content=False) state = { key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } - if event_type: - defer.returnValue(state.get((event_type, state_key))) - return - defer.returnValue(state) @defer.inlineCallbacks -- cgit 1.5.1 From 1294d4a3299faba9b5f09ec6f452dfb2ab9f5e35 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 25 Aug 2016 18:34:47 +0100 Subject: Move ThirdPartyEntityKind into api.constants so the expectation becomes that the value is significant --- synapse/api/constants.py | 5 +++++ synapse/appservice/api.py | 2 +- synapse/rest/client/v2_alpha/thirdparty.py | 2 +- synapse/types.py | 7 ------- 4 files changed, 7 insertions(+), 9 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8cf4d6169c..a8123cddcb 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -85,3 +85,8 @@ class RoomCreationPreset(object): PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + + +class ThirdPartyEntityKind(object): + USER = "user" + LOCATION = "location" diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 632dc1a4f8..24253e7785 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,11 +14,11 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event from synapse.util.caches.response_cache import ResponseCache -from synapse.types import ThirdPartyEntityKind import logging import urllib diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index bbc3e9b962..4f6f1a7e17 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -18,8 +18,8 @@ import logging from twisted.internet import defer +from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet -from synapse.types import ThirdPartyEntityKind from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/types.py b/synapse/types.py index fd17ecbbe0..5349b0c450 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -269,10 +269,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) - - -# Some arbitrary constants used for internal API enumerations. Don't rely on -# exact values; always pass or compare symbolically -class ThirdPartyEntityKind(object): - USER = 'user' - LOCATION = 'location' -- cgit 1.5.1 From 25414b44a2edbf7cc66e46968e81b56ae32f2887 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Aug 2016 10:47:00 +0100 Subject: Add measure on check_host_in_room --- synapse/api/auth.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 23a928de16..597631a88f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -278,20 +278,21 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): - curr_state_id = yield self.state.get_current_state_ids(room_id) + with Measure(self.clock, "check_host_in_room"): + curr_state_id = yield self.state.get_current_state_ids(room_id) - for (etype, state_key), event_id in curr_state_id.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: + for (etype, state_key), event_id in curr_state_id.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - event = yield self.store.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) defer.returnValue(False) -- cgit 1.5.1 From 1ccdc1e93a5ae854fa89751a78c9103940a9f9e6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Aug 2016 10:59:40 +0100 Subject: Cache check_host_in_room --- synapse/api/auth.py | 20 ++++++-------------- synapse/storage/roommember.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 14 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 597631a88f..f26e585623 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -279,22 +279,14 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - curr_state_id = yield self.state.get_current_state_ids(room_id) + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - for (etype, state_key), event_id in curr_state_id.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - event = yield self.store.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + group, curr_state_ids = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - defer.returnValue(False) + ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) + defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): key = (EventTypes.Member, event.user_id, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 2cab065bca..5ce5e8da37 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -400,3 +400,38 @@ class RoomMemberStore(SQLBaseStore): ) defer.returnValue(set(row["user_id"] for row in rows)) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) -- cgit 1.5.1 From c10cb581c6ce54e7dfa1f8a0f6449ee7f6d049d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 31 Aug 2016 13:55:02 +0100 Subject: Correctly handle the difference between prev and current state --- synapse/api/auth.py | 4 +-- synapse/events/snapshot.py | 5 ++-- synapse/handlers/federation.py | 31 +++++++++++++++------ synapse/handlers/message.py | 10 +++---- synapse/handlers/room_member.py | 6 ++-- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/state.py | 31 +++++++++++++++------ synapse/storage/roommember.py | 27 +++++++++++++++--- synapse/storage/state.py | 3 -- tests/replication/slave/storage/test_events.py | 4 ++- tests/replication/test_resource.py | 10 ++----- tests/test_state.py | 38 ++++++++++---------------- 12 files changed, 102 insertions(+), 69 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f26e585623..fcf0b0d25f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -66,7 +66,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): auth_events_ids = yield self.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -852,7 +852,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) + auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index c75afd02d8..e895b1c450 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,8 +15,9 @@ class EventContext(object): - def __init__(self, current_state_ids=None): - self.current_state_ids = current_state_ids + def __init__(self): + self.current_state_ids = None + self.prev_state_ids = None self.state_group = None self.rejected = False self.push_actions = [] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a7ea8fb98f..8e61d74b13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -222,7 +222,7 @@ class FederationHandler(BaseHandler): # joined the room. Don't bother if the user is just # changing their profile info. newly_joined = True - prev_state_id = context.current_state_ids.get( + prev_state_id = context.prev_state_ids.get( (event.type, event.state_key) ) if prev_state_id: @@ -835,12 +835,12 @@ class FederationHandler(BaseHandler): self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = context.current_state_ids.values() + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) - state = yield self.store.get_events(context.current_state_ids.values()) + state = yield self.store.get_events(context.prev_state_ids.values()) defer.returnValue({ "state": state.values(), @@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler): if not auth_events: auth_events_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler): context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() }) - context.state_group = None + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. auth_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler): context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() }) - context.state_group = None + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) @@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"] ) original_invite = None - original_invite_id = context.current_state_ids.get(key) + original_invite_id = context.prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( original_invite_id, allow_none=True @@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event_id = context.current_state_ids.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2f4387f60..3577db0595 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -272,7 +272,7 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event_id = context.current_state_ids.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -808,8 +808,8 @@ class MessageHandler(BaseHandler): event = builder.build() logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state_ids, + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, ) defer.returnValue( @@ -904,7 +904,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Redaction: auth_events_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -924,7 +924,7 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.current_state_ids: + if event.type == EventTypes.Create and context.prev_state_ids: raise AuthError( 403, "Changing the room create event is forbidden", diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index dd4b90ee24..3ba5335af7 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.current_state_ids.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) @@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler): if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(context.current_state_ids) + guest_can_join = yield self._can_guest_join(context.prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.current_state_ids.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, event.state_key), None ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 51cb21ee9d..6ff9a06de1 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,7 +87,7 @@ class BulkPushRuleEvaluator: ) room_members = yield self.store.get_joined_users_from_context( - event.room_id, context.state_group, context.current_state_ids + event, context ) evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) diff --git a/synapse/state.py b/synapse/state.py index 147416fd81..a0f807e3b9 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -128,7 +128,7 @@ class StateHandler(object): def get_current_user_in_room(self, room_id): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_context( + joined_users = yield self.store.get_joined_users_from_state( room_id, group, state_ids ) defer.returnValue(joined_users) @@ -154,27 +154,38 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } + if event.is_state(): + context.current_state_events = dict(context.prev_state_ids) + key = (event.type, event.state_key) + context.current_state_events[key] = event.event_id + else: + context.current_state_events = context.prev_state_ids else: context.current_state_ids = {} + context.prev_state_ids = {} context.prev_state_events = [] context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] if replaces != event.event_id: # Paranoia check event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) @@ -192,7 +203,7 @@ class StateHandler(object): group, curr_state = ret - context.current_state_ids = curr_state + context.prev_state_ids = curr_state if event.is_state() or group is None: context.state_group = self.store.get_next_state_group() else: @@ -200,9 +211,13 @@ class StateHandler(object): if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index cab1660830..6ab10db328 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore): desc="who_forgot" ) - def get_joined_users_from_context(self, room_id, state_group, state_ids): + def get_joined_users_from_context(self, event, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_ids + event.room_id, state_group, context.current_state_ids, event=event, + ) + + def get_joined_users_from_state(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids, ) @cachedInlineCallbacks(num_args=2, cache_context=True) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context): + cache_context, event=None): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. @@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore): desc="_get_joined_users_from_context", ) - defer.returnValue(set(row["user_id"] for row in rows)) + users_in_room = set(row["user_id"] for row in rows) + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room.add(event.state_key) + + defer.returnValue(users_in_room) def is_host_joined(self, room_id, host, state_group, state_ids): if not state_group: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 56bfdc0b55..dce5a2f135 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -108,9 +108,6 @@ class StateStore(SQLBaseStore): state_event_ids = dict(context.current_state_ids) - if event.is_state(): - state_event_ids[(event.type, event.state_key)] = event.event_id - self._simple_insert_txn( txn, table="state_groups", diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 218cb24889..44e859b5d1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): else: state_ids = None - context = EventContext(current_state_ids=state_ids) + context = EventContext() + context.current_state_ids = state_ids + context.prev_state_ids = state_ids context.push_actions = push_actions ordering = None diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index e70ac6f14d..b69832cc1b 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events_and_state(self): - get = self.get(events="-1", state="-1", timeout="0") + def test_events(self): + get = self.get(events="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( synapse.types.create_requester(self.user), {} ) @@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body["events"]["field_names"], [ "position", "internal", "json", "state_group" ]) - self.assertEquals(body["state_groups"]["field_names"], [ - "position", "room_id", "event_id" - ]) - self.assertEquals(body["state_group_state"]["field_names"], [ - "position", "type", "state_key", "event_id" - ]) @defer.inlineCallbacks def test_presence(self): diff --git a/tests/test_state.py b/tests/test_state.py index de2d35145a..6454f994e3 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -86,17 +86,8 @@ class StateGroupStore(object): state_events = dict(context.current_state_ids) - if event.is_state(): - state_events[(event.type, event.state_key)] = event.event_id - - state_group = context.state_group - if not state_group: - state_group = self._next_group - self._next_group += 1 - - self._group_to_state[state_group] = state_events - - self._event_to_state_group[event.event_id] = state_group + self._group_to_state[context.state_group] = state_events + self._event_to_state_group[event.event_id] = context.state_group def get_events(self, event_ids, **kwargs): return { @@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase): "get_state_groups_ids", "add_event_hashes", "get_events", + "get_next_state_group", ] ) hs = Mock(spec_set=[ @@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + self.store.get_next_state_group.side_effect = Mock + self.state = StateHandler(hs) self.event_id = 0 @@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state_ids)) + self.assertEqual(2, len(context_store["D"].prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].current_state_ids.values()} + {e_id for e_id in context_store["D"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].current_state_ids.values()} + {e for e in context_store["E"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].current_state_ids.values()} + {e for e in context_store["D"].prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase): set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) - @defer.inlineCallbacks def test_trivial_annotate_message(self): event = create_event(type="test_message", name="event") @@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase): self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): @@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): @@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): -- cgit 1.5.1 From ed7a703d4c61feaae437cd4bc11c2afea2dc4ad4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 31 Aug 2016 15:53:19 +0100 Subject: Handle the fact that workers can't generate state groups --- synapse/api/auth.py | 6 ++-- synapse/state.py | 81 ++++++++++++++++++++++++++++++++++++----------------- 2 files changed, 60 insertions(+), 27 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index fcf0b0d25f..dcda40863f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -281,11 +281,13 @@ class Auth(object): with Measure(self.clock, "check_host_in_room"): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - group, curr_state_ids = yield self.state.resolve_state_groups( + entry = yield self.state.resolve_state_groups( room_id, latest_event_ids ) - ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) + ret = yield self.store.is_host_joined( + room_id, host, entry.state_group, entry.state + ) defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): diff --git a/synapse/state.py b/synapse/state.py index 4d48cc4605..b31bbcdbd2 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -43,11 +43,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 +_NEXT_STATE_ID = 1 + + +def _gen_state_id(): + global _NEXT_STATE_ID + s = "X%d" % (_NEXT_STATE_ID,) + _NEXT_STATE_ID += 1 + return s + + class _StateCacheEntry(object): - def __init__(self, state, state_group, ts): + __slots__ = ["state", "state_group", "state_id"] + + def __init__(self, state, state_group): self.state = state self.state_group = state_group + # The `state_id` is a unique ID we generate that can be used as ID for + # this collection of state. Usually this would be the same as the + # state group, but on worker instances we can't generate a new state + # group each time we resolve state, so we generate a separate one that + # isn't persisted and is used solely for caches. + # `state_id` is either a state_group (and so an int) or a string. This + # ensures we don't accidentally persist a state_id as a stateg_group + if state_group: + self.state_id = state_group + else: + self.state_id = _gen_state_id() + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -93,7 +117,8 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state if event_type: event_id = state.get((event_type, state_key)) @@ -116,7 +141,8 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state if event_type: defer.returnValue(state.get((event_type, state_key))) @@ -127,9 +153,9 @@ class StateHandler(object): @defer.inlineCallbacks def get_current_user_in_room(self, room_id): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) + entry = yield self.resolve_state_groups(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state( - room_id, group, state_ids + room_id, entry.state_id, entry.state ) defer.returnValue(joined_users) @@ -191,23 +217,26 @@ class StateHandler(object): defer.returnValue(context) if event.is_state(): - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], ) - group, curr_state = ret + curr_state = entry.state context.prev_state_ids = curr_state - if event.is_state() or group is None: + if event.is_state(): context.state_group = self.store.get_next_state_group() else: - context.state_group = group + if entry.state_group is None: + entry.state_group = self.store.get_next_state_group() + entry.state_id = entry.state_group + context.state_group = entry.state_group if event.is_state(): key = (event.type, event.state_key) @@ -249,16 +278,15 @@ class StateHandler(object): if len(group_names) == 1: name, state_list = state_groups_ids.items().pop() - defer.returnValue((name, state_list,)) + defer.returnValue(_StateCacheEntry( + state=state_list, + state_group=name, + )) if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: - cache.ts = self.clock.time_msec() - - defer.returnValue( - (cache.state_group, cache.state,) - ) + defer.returnValue(cache) logger.info( "Resolving state for %s with %d groups", room_id, len(state_groups_ids) @@ -302,19 +330,22 @@ class StateHandler(object): if new_state_event_ids == frozenset(e_id for e_id in events): state_group = sg break - if not state_group: - state_group = self.store.get_next_state_group() + if state_group is None: + # Worker instances don't have access to this method, but we want + # to set the state_group on the main instance to increase cache + # hits. + if hasattr(self.store, "get_next_state_group"): + state_group = self.store.get_next_state_group() + + cache = _StateCacheEntry( + state=new_state, + state_group=state_group, + ) if self._state_cache is not None: - cache = _StateCacheEntry( - state=new_state, - state_group=state_group, - ts=self.clock.time_msec() - ) - self._state_cache[group_names] = cache - defer.returnValue((state_group, new_state,)) + defer.returnValue(cache) def resolve_events(self, state_sets, event): logger.info( -- cgit 1.5.1