From 80472ac1988da937f8318651fe5fc764ea869d35 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Oct 2014 10:04:55 +0100 Subject: Add missing package storate.state --- synapse/storage/state.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 synapse/storage/state.py (limited to 'synapse/storage/state.py') diff --git a/synapse/storage/state.py b/synapse/storage/state.py new file mode 100644 index 0000000000..b910646d74 --- /dev/null +++ b/synapse/storage/state.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +from collections import namedtuple + + +StateGroup = namedtuple("StateGroup", ("group", "state")) + + +class StateStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_state_groups(self, event_ids): + groups = set() + for event_id in event_ids: + group = yield self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + groups.add(group) + + res = [] + for group in groups: + state_ids = yield self._simple_select_onecol( + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + state = [] + for state_id in state_ids: + s = yield self.get_event( + state_id, + allow_none=True, + ) + if s: + state.append(s) + + res.append(StateGroup(group, state)) + + defer.returnValue(res) + + def store_state_groups(self, event): + return self.runInteraction( + self._store_state_groups_txn, event + ) + + def _store_state_groups_txn(self, txn, event): + state_group = event.state_group + if not state_group: + state_group = self._simple_insert_txn( + txn, + table="state_groups", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + } + ) + + for state in event.state_events: + self._simple_insert_txn( + txn, + table="state_groups_state", + values={ + "state_group": state_group, + "room_id": state.room_id, + "type": state.type, + "state_key": state.state_key, + "event_id": state.event_id, + } + ) + + self._simple_insert_txn( + txn, + table="event_to_state_groups", + values={ + "state_group": state_group, + "event_id": event.event_id, + } + ) -- cgit 1.4.1 From e7bc1291a079224315cea5c756061ad711241be1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Oct 2014 16:06:59 +0100 Subject: Begin making auth use event.old_state_events --- synapse/api/auth.py | 113 ++++++++++++++++++++++++--------------- synapse/handlers/_base.py | 10 +++- synapse/handlers/directory.py | 5 +- synapse/handlers/federation.py | 11 +++- synapse/handlers/message.py | 13 ++--- synapse/handlers/profile.py | 5 +- synapse/handlers/room.py | 19 ++++--- synapse/state.py | 18 +++---- synapse/storage/schema/state.sql | 2 +- synapse/storage/state.py | 2 +- 10 files changed, 115 insertions(+), 83 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e1b1823cd7..d951cb265b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,6 +21,7 @@ from synapse.api.constants import Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, + RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, ) from synapse.util.logutils import log_function @@ -55,11 +56,7 @@ class Auth(object): defer.returnValue(allowed) return - self._check_joined_room( - member=snapshot.membership_state, - user_id=snapshot.user_id, - room_id=snapshot.room_id, - ) + self.check_event_sender_in_room(event) if is_state: # TODO (erikj): This really only should be called for *new* @@ -98,6 +95,16 @@ class Auth(object): pass defer.returnValue(None) + def check_event_sender_in_room(self, event): + key = (RoomMemberEvent.TYPE, event.user_id, ) + member_event = event.state_events.get(key) + + return self._check_joined_room( + member_event, + event.user_id, + event.room_id + ) + def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( @@ -114,29 +121,39 @@ class Auth(object): raise AuthError(403, "Room does not exist") # get info about the caller - try: - caller = yield self.store.get_room_member( - user_id=event.user_id, - room_id=event.room_id) - except: - caller = None + key = (RoomMemberEvent.TYPE, event.user_id, ) + caller = event.old_state_events.get(key) + caller_in_room = caller and caller.membership == "join" # get info about the target - try: - target = yield self.store.get_room_member( - user_id=target_user_id, - room_id=event.room_id) - except: - target = None + key = (RoomMemberEvent.TYPE, target_user_id, ) + target = event.old_state_events.get(key) + target_in_room = target and target.membership == "join" membership = event.content["membership"] - join_rule = yield self.store.get_room_join_rule(event.room_id) - if not join_rule: + key = (RoomJoinRulesEvent.TYPE, "", ) + join_rule_event = event.old_state_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: join_rule = JoinRules.INVITE + user_level = self._get_power_level_from_event_state( + event, + event.user_id, + ) + + ban_level, kick_level, redact_level = ( + yield self._get_ops_level_from_event_state( + event + ) + ) + if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently # PRIVATE join rules. @@ -171,29 +188,16 @@ class Auth(object): if not caller_in_room: # trying to leave a room you aren't joined raise AuthError(403, "You are not in room %s." % event.room_id) elif target_user_id != event.user_id: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - _, kick_level, _ = yield self.store.get_ops_levels(event.room_id) - if kick_level: kick_level = int(kick_level) else: - kick_level = 50 + kick_level = 50 # FIXME (erikj): What should we do here? if user_level < kick_level: raise AuthError( 403, "You cannot kick user %s." % target_user_id ) elif Membership.BAN == membership: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - - ban_level, _, _ = yield self.store.get_ops_levels(event.room_id) - if ban_level: ban_level = int(ban_level) else: @@ -206,6 +210,29 @@ class Auth(object): defer.returnValue(True) + def _get_power_level_from_event_state(self, event, user_id): + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + level = None + if power_level_event: + level = power_level_event.content[user_id] + if not level: + level = power_level_event.content["default"] + + return level + + def _get_ops_level_from_event_state(self, event): + key = (RoomOpsPowerLevelsEvent.TYPE, "", ) + ops_event = event.old_state_events.get(key) + + if ops_event: + return ( + ops_event.content.get("ban_level"), + ops_event.content.get("kick_level"), + ops_event.content.get("redact_level"), + ) + return None, None, None, + @defer.inlineCallbacks def get_user_by_req(self, request): """ Get a registered user's ID. @@ -282,8 +309,8 @@ class Auth(object): else: send_level = 0 - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -308,8 +335,8 @@ class Auth(object): add_level = int(add_level) - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -333,8 +360,8 @@ class Auth(object): if current_state: current_state = current_state[0] - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -363,10 +390,10 @@ class Auth(object): event.user_id, ) - if user_level: - user_level = int(user_level) - else: - user_level = 0 + user_level = self._get_power_level_from_event_state( + event, + event.user_id, + ) _, _, redact_level = yield self.store.get_ops_levels(event.room_id) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index de4d23bbb3..cd6c35f194 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -44,9 +44,17 @@ class BaseHandler(object): @defer.inlineCallbacks def _on_new_room_event(self, event, snapshot, extra_destinations=[], - extra_users=[]): + extra_users=[], suppress_auth=False): snapshot.fill_out_prev_events(event) + yield self.state_handler.annotate_state_groups(event) + + if not suppress_auth: + yield self.auth.check(event, snapshot, raises=True) + + if hasattr(event, "state_key"): + yield self.state_handler.handle_new_event(event, snapshot) + yield self.store.persist_event(event) destinations = set(extra_destinations) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a56830d520..6e897e915d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -152,5 +152,6 @@ class DirectoryHandler(BaseHandler): user_id=user_id, ) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user_id]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user_id], suppress_auth=True + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f52591d2a3..44bf7def2e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -95,6 +95,8 @@ class FederationHandler(BaseHandler): logger.debug("Got event: %s", event.event_id) + yield self.state_handler.annotate_state_groups(event) + with (yield self.lock_manager.lock(pdu.context)): if event.is_state and not backfilled: is_new_state = yield self.state_handler.handle_new_state( @@ -195,7 +197,12 @@ class FederationHandler(BaseHandler): for pdu in pdus: event = self.pdu_codec.event_from_pdu(pdu) + + # FIXME (erikj): Not sure this actually works :/ + yield self.state_handler.annotate_state_groups(event) + events.append(event) + yield self.store.persist_event(event, backfilled=True) defer.returnValue(events) @@ -235,6 +242,7 @@ class FederationHandler(BaseHandler): new_event.destinations = [target_host] snapshot.fill_out_prev_events(new_event) + yield self.state_handler.annotate_state_groups(new_event) yield self.handle_new_event(new_event, snapshot) # TODO (erikj): Time out here. @@ -254,12 +262,11 @@ class FederationHandler(BaseHandler): is_public=False ) except: + # FIXME pass - defer.returnValue(True) - @log_function def _on_user_joined(self, user, room_id): waiters = self.waiting_for_join_list.get((user.to_string(), room_id), []) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 317ef2c80c..1c2cbce151 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -87,10 +87,9 @@ class MessageHandler(BaseHandler): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - if not suppress_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self._on_new_room_event(event, snapshot) + yield self._on_new_room_event( + event, snapshot, suppress_auth=suppress_auth + ) self.hs.get_handlers().presence_handler.bump_presence_active_time( user @@ -149,13 +148,9 @@ class MessageHandler(BaseHandler): state_key=event.state_key, ) - yield self.auth.check(event, snapshot, raises=True) - if stamp_event: event.content["hsob_ts"] = int(self.clock.time_msec()) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot) @defer.inlineCallbacks @@ -227,8 +222,6 @@ class MessageHandler(BaseHandler): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - yield self.auth.check(event, snapshot, raises=True) - # store message in db yield self._on_new_room_event(event, snapshot) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dab9b03f04..4cd0a06093 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -218,5 +218,6 @@ class ProfileHandler(BaseHandler): user_id=j.state_key, ) - yield self.state_handler.handle_new_event(new_event, snapshot) - yield self._on_new_room_event(new_event, snapshot) + yield self._on_new_room_event( + new_event, snapshot, suppress_auth=True + ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index c0f9a7c807..cb5bd17d2b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -129,8 +129,9 @@ class RoomCreationHandler(BaseHandler): logger.debug("Event: %s", event) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user], suppress_auth=True + ) for event in creation_events: yield handle_event(event) @@ -396,8 +397,6 @@ class RoomMemberHandler(BaseHandler): yield self._do_join(event, snapshot, do_auth=do_auth) else: # This is not a JOIN, so we can handle it normally. - if do_auth: - yield self.auth.check(event, snapshot, raises=True) # If we're banning someone, set a req power level if event.membership == Membership.BAN: @@ -419,6 +418,7 @@ class RoomMemberHandler(BaseHandler): event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) defer.returnValue({"room_id": room_id}) @@ -507,14 +507,11 @@ class RoomMemberHandler(BaseHandler): if not have_joined: logger.debug("Doing normal join") - if do_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) user = self.hs.parse_userid(event.user_id) @@ -558,7 +555,8 @@ class RoomMemberHandler(BaseHandler): defer.returnValue([r.room_id for r in rooms]) - def _do_local_membership_update(self, event, membership, snapshot): + def _do_local_membership_update(self, event, membership, snapshot, + do_auth): destinations = [] # If we're inviting someone, then we should also send it to that @@ -575,9 +573,10 @@ class RoomMemberHandler(BaseHandler): return self._on_new_room_event( event, snapshot, extra_destinations=destinations, - extra_users=[target_user] + extra_users=[target_user], suppress_auth=(not do_auth), ) + class RoomListHandler(BaseHandler): @defer.inlineCallbacks diff --git a/synapse/state.py b/synapse/state.py index 8f09b7d50a..9be6b716e2 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -71,6 +71,7 @@ class StateHandler(object): # (w.r.t. to power levels) snapshot.fill_out_prev_events(event) + yield self.annotate_state_groups(event) event.prev_events = [ e for e in event.prev_events if e != event.event_id @@ -83,8 +84,6 @@ class StateHandler(object): current_state.pdu_id, current_state.origin ) - yield self.update_state_groups(event) - # TODO check current_state to see if the min power level is less # than the power level of the user # power_level = self._get_power_level_for_event(event) @@ -131,21 +130,16 @@ class StateHandler(object): defer.returnValue(is_new) @defer.inlineCallbacks - def update_state_groups(self, event): + def annotate_state_groups(self, event): state_groups = yield self.store.get_state_groups( event.prev_events ) - if len(state_groups) == 1 and not hasattr(event, "state_key"): - event.state_group = state_groups[0].group - event.current_state = state_groups[0].state - return - state = {} state_sets = {} for group in state_groups: for s in group.state: - state.setdefault((s.type, s.state_key), []).add(s) + state.setdefault((s.type, s.state_key), []).append(s) state_sets.setdefault( (s.type, s.state_key), @@ -153,7 +147,7 @@ class StateHandler(object): ).add(s.event_id) unconflicted_state = { - k: v.pop() for k, v in state_sets.items() + k: state[k].pop() for k, v in state_sets.items() if len(v) == 1 } @@ -168,11 +162,13 @@ class StateHandler(object): for key, events in conflicted_state.items(): new_state[key] = yield self.resolve(events) + event.old_state_events = new_state + if hasattr(event, "state_key"): new_state[(event.type, event.state_key)] = event event.state_group = None - event.current_state = new_state.values() + event.state_events = new_state @defer.inlineCallbacks def resolve(self, events): diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql index b5c345fae7..b44c56b519 100644 --- a/synapse/storage/schema/state.sql +++ b/synapse/storage/schema/state.sql @@ -14,7 +14,7 @@ */ CREATE TABLE IF NOT EXISTS state_groups( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id INTEGER PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b910646d74..9496c935a7 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -74,7 +74,7 @@ class StateStore(SQLBaseStore): } ) - for state in event.state_events: + for state in event.state_events.values(): self._simple_insert_txn( txn, table="state_groups_state", -- cgit 1.4.1 From 5ffe5ab43fa090111a0141b04ce6342172f60724 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Oct 2014 18:56:42 +0100 Subject: Use state groups to get current state. Make join dance actually work. --- synapse/api/auth.py | 5 +++ synapse/federation/replication.py | 17 +++++++- synapse/federation/transport.py | 57 +++++++++++++++++++++++--- synapse/handlers/federation.py | 74 +++++++++++++++++++++++---------- synapse/handlers/message.py | 6 +-- synapse/rest/base.py | 5 +++ synapse/rest/events.py | 34 ++++++++++------ synapse/state.py | 86 +++++++++++++++++++++++++++------------ synapse/storage/pdu.py | 6 +++ synapse/storage/state.py | 3 ++ 10 files changed, 226 insertions(+), 67 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d1eca791ab..50ce7eb4cd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -22,6 +22,7 @@ from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent, + RoomCreateEvent, ) from synapse.util.logutils import log_function @@ -59,6 +60,10 @@ class Auth(object): is_state = hasattr(event, "state_key") + if event.type == RoomCreateEvent.TYPE: + # FIXME + defer.returnValue(True) + if event.type == RoomMemberEvent.TYPE: yield self._can_replace_state(event) allowed = yield self.is_membership_change_allowed(event) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index d482193851..8c7d510ef6 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -403,11 +403,18 @@ class ReplicationLayer(object): defer.returnValue( (404, "No handler for Query type '%s'" % (query_type, )) ) + @defer.inlineCallbacks def on_make_join_request(self, context, user_id): pdu = yield self.handler.on_make_join_request(context, user_id) defer.returnValue(pdu.get_dict()) + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = Pdu(**content) + ret_pdu = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, ret_pdu.get_dict())) + @defer.inlineCallbacks def on_send_join_request(self, origin, content): pdu = Pdu(**content) @@ -426,8 +433,9 @@ class ReplicationLayer(object): defer.returnValue(Pdu(**pdu_dict)) + @defer.inlineCallbacks def send_join(self, destination, pdu): - return self.transport_layer.send_join( + _, content = yield self.transport_layer.send_join( destination, pdu.context, pdu.pdu_id, @@ -435,6 +443,13 @@ class ReplicationLayer(object): pdu.get_dict(), ) + logger.debug("Got content: %s", content) + pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])] + for pdu in pdus: + yield self._handle_new_pdu(destination, pdu) + + defer.returnValue(pdus) + @defer.inlineCallbacks @log_function def _get_persisted_pdu(self, pdu_id, pdu_origin): diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index a0d34fd24d..de64702e2f 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -229,13 +229,36 @@ class TransportLayer(object): pdu_id, ) - response = yield self.client.put_json( + code, content = yield self.client.put_json( destination=destination, path=path, data=content, ) - defer.returnValue(response) + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, context, pdu_id, origin, content): + path = PREFIX + "/invite/%s/%s/%s" % ( + context, + origin, + pdu_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) @defer.inlineCallbacks def _authenticate_request(self, request): @@ -297,9 +320,13 @@ class TransportLayer(object): @defer.inlineCallbacks def new_handler(request, *args, **kwargs): (origin, content) = yield self._authenticate_request(request) - response = yield handler( - origin, content, request.args, *args, **kwargs - ) + try: + response = yield handler( + origin, content, request.args, *args, **kwargs + ) + except: + logger.exception("Callback failed") + raise defer.returnValue(response) return new_handler @@ -419,6 +446,17 @@ class TransportLayer(object): ) ) + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, pdu_origin, pdu_id: + self._on_invite_request( + origin, content, query, + ) + ) + ) + @defer.inlineCallbacks @log_function def _on_send_request(self, origin, content, query, transaction_id): @@ -524,6 +562,15 @@ class TransportLayer(object): defer.returnValue((200, content)) + @defer.inlineCallbacks + @log_function + def _on_invite_request(self, origin, content, query): + content = yield self.request_handler.on_invite_request( + origin, content, + ) + + defer.returnValue((200, content)) + class TransportReceivedHandler(object): """ Callbacks used when we receive a transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0ae0541bd3..70790aaa72 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -62,6 +62,9 @@ class FederationHandler(BaseHandler): self.pdu_codec = PduCodec(hs) + # When joining a room we need to queue any events for that room up + self.room_queues = {} + @log_function @defer.inlineCallbacks def handle_new_event(self, event, snapshot): @@ -95,22 +98,25 @@ class FederationHandler(BaseHandler): logger.debug("Got event: %s", event.event_id) + if event.room_id in self.room_queues: + self.room_queues[event.room_id].append(pdu) + return + if state: state = [self.pdu_codec.event_from_pdu(p) for p in state] state = {(e.type, e.state_key): e for e in state} - yield self.state_handler.annotate_state_groups(event, state=state) + + is_new_state = yield self.state_handler.annotate_state_groups( + event, + state=state + ) logger.debug("Event: %s", event) if not backfilled: yield self.auth.check(event, None, raises=True) - if event.is_state and not backfilled: - is_new_state = yield self.state_handler.handle_new_state( - pdu - ) - else: - is_new_state = False + is_new_state = is_new_state and not backfilled # TODO: Implement something in federation that allows us to # respond to PDU. @@ -211,6 +217,8 @@ class FederationHandler(BaseHandler): assert(event.state_key == joinee) assert(event.room_id == room_id) + self.room_queues[room_id] = [] + event.event_id = self.event_factory.create_event_id() event.content = content @@ -219,15 +227,14 @@ class FederationHandler(BaseHandler): self.pdu_codec.pdu_from_event(event) ) - # TODO (erikj): Time out here. - d = defer.Deferred() - self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d) - reactor.callLater(10, d.cancel) + state = [self.pdu_codec.event_from_pdu(p) for p in state] - try: - yield d - except defer.CancelledError: - raise SynapseError(500, "Unable to join remote room") + logger.debug("do_invite_join state: %s", state) + + is_new_state = yield self.state_handler.annotate_state_groups( + event, + state=state + ) try: yield self.store.store_room( @@ -239,6 +246,32 @@ class FederationHandler(BaseHandler): # FIXME pass + for e in state: + # FIXME: Auth these. + is_new_state = yield self.state_handler.annotate_state_groups( + e, + state=state + ) + + yield self.store.persist_event( + e, + backfilled=False, + is_new_state=False + ) + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state + ) + + room_queue = self.room_queues[room_id] + del self.room_queues[room_id] + + for p in room_queue: + p.outlier = True + yield self.on_receive_pdu(p, backfilled=False) + defer.returnValue(True) @defer.inlineCallbacks @@ -264,13 +297,9 @@ class FederationHandler(BaseHandler): def on_send_join_request(self, origin, pdu): event = self.pdu_codec.event_from_pdu(pdu) - yield self.state_handler.annotate_state_groups(event) + is_new_state= yield self.state_handler.annotate_state_groups(event) yield self.auth.check(event, None, raises=True) - is_new_state = yield self.state_handler.handle_new_state( - pdu - ) - # FIXME (erikj): All this is duplicated above :( yield self.store.persist_event( @@ -303,7 +332,10 @@ class FederationHandler(BaseHandler): yield self.replication_layer.send_pdu(new_pdu) - defer.returnValue(event.state_events.values()) + defer.returnValue([ + self.pdu_codec.pdu_from_event(e) + for e in event.state_events.values() + ]) @defer.inlineCallbacks def get_state_for_pdu(self, pdu_id, pdu_origin): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1c2cbce151..4aaf97a83e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -199,7 +199,7 @@ class MessageHandler(BaseHandler): raise RoomError( 403, "Member does not meet private room rules.") - data = yield self.store.get_current_state( + data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) defer.returnValue(data) @@ -238,7 +238,7 @@ class MessageHandler(BaseHandler): yield self.auth.check_joined_room(room_id, user_id) # TODO: This is duplicating logic from snapshot_all_rooms - current_state = yield self.store.get_current_state(room_id) + current_state = yield self.state_handler.get_current_state(room_id) defer.returnValue([self.hs.serialize_event(c) for c in current_state]) @defer.inlineCallbacks @@ -315,7 +315,7 @@ class MessageHandler(BaseHandler): "end": end_token.to_string(), } - current_state = yield self.store.get_current_state( + current_state = yield self.state_handler.get_current_state( event.room_id ) d["state"] = [self.hs.serialize_event(c) for c in current_state] diff --git a/synapse/rest/base.py b/synapse/rest/base.py index 2e8e3fa7d4..dc784c1527 100644 --- a/synapse/rest/base.py +++ b/synapse/rest/base.py @@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX from synapse.rest.transactions import HttpTransactionStore import re +import logging + + +logger = logging.getLogger(__name__) + def client_path_pattern(path_regex): """Creates a regex compiled client path with the correct client path diff --git a/synapse/rest/events.py b/synapse/rest/events.py index 097195d7cc..92ff5e5ca7 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/events.py @@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig from synapse.rest.base import RestServlet, client_path_pattern +import logging + + +logger = logging.getLogger(__name__) + + class EventStreamRestServlet(RestServlet): PATTERN = client_path_pattern("/events$") @@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): auth_user = yield self.auth.get_user_by_req(request) - - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - chunk = yield handler.get_stream(auth_user.to_string(), pagin_config, - timeout=timeout) + try: + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + chunk = yield handler.get_stream( + auth_user.to_string(), pagin_config, timeout=timeout + ) + except: + logger.exception("Event stream failed") + raise defer.returnValue((200, chunk)) diff --git a/synapse/state.py b/synapse/state.py index 8c4eeb8924..24685c6fb4 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.federation.pdu_codec import encode_event_id, decode_event_id from synapse.util.logutils import log_function +from synapse.federation.pdu_codec import encode_event_id from collections import namedtuple @@ -130,54 +131,89 @@ class StateHandler(object): defer.returnValue(is_new) @defer.inlineCallbacks + @log_function def annotate_state_groups(self, event, state=None): if state: event.state_group = None event.old_state_events = None - event.state_events = state + event.state_events = {(s.type, s.state_key): s for s in state} + defer.returnValue(False) + return + + if hasattr(event, "outlier") and event.outlier: + event.state_group = None + event.old_state_events = None + event.state_events = None + defer.returnValue(False) return + new_state = yield self.resolve_state_groups(event.prev_events) + + event.old_state_events = new_state + + if hasattr(event, "state_key"): + new_state[(event.type, event.state_key)] = event + + event.state_group = None + event.state_events = new_state + + defer.returnValue(hasattr(event, "state_key")) + + @defer.inlineCallbacks + def get_current_state(self, room_id, event_type=None, state_key=""): + # FIXME: HACK! + pdus = yield self.store.get_latest_pdus_in_context(room_id) + + event_ids = [encode_event_id(p.pdu_id, p.origin) for p in pdus] + + res = self.resolve_state_groups(event_ids) + + if event_type: + defer.returnValue(res.get((event_type, state_key))) + return + + defer.returnValue(res.values()) + + @defer.inlineCallbacks + @log_function + def resolve_state_groups(self, event_ids): state_groups = yield self.store.get_state_groups( - event.prev_events + event_ids ) state = {} - state_sets = {} for group in state_groups: for s in group.state: - state.setdefault((s.type, s.state_key), []).append(s) - - state_sets.setdefault( + state.setdefault( (s.type, s.state_key), - set() - ).add(s.event_id) + {} + )[s.event_id] = s unconflicted_state = { - k: state[k].pop() for k, v in state_sets.items() - if len(v) == 1 + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 } conflicted_state = { - k: state[k] - for k, v in state_sets.items() - if len(v) > 1 + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 } - new_state = {} - new_state.update(unconflicted_state) - for key, events in conflicted_state.items(): - new_state[key] = yield self.resolve(events) + try: + new_state = {} + new_state.update(unconflicted_state) + for key, events in conflicted_state.items(): + new_state[key] = yield self._resolve_state_events(events) + except: + logger.exception("Failed to resolve state") + raise - event.old_state_events = new_state - - if hasattr(event, "state_key"): - new_state[(event.type, event.state_key)] = event - - event.state_group = None - event.state_events = new_state + defer.returnValue(new_state) @defer.inlineCallbacks - def resolve(self, events): + @log_function + def _resolve_state_events(self, events): curr_events = events new_powers_deferreds = [] diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index d70467dcd6..b1cb0185a6 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -277,6 +277,12 @@ class PduStore(SQLBaseStore): (context, depth) ) + def get_latest_pdus_in_context(self, context): + return self.runInteraction( + self._get_latest_pdus_in_context, + context + ) + def _get_latest_pdus_in_context(self, txn, context): """Get's a list of the most current pdus for a given context. This is used when we are sending a Pdu and need to fill out the `prev_pdus` diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 9496c935a7..0aa979c9f0 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -63,6 +63,9 @@ class StateStore(SQLBaseStore): ) def _store_state_groups_txn(self, txn, event): + if not event.state_events: + return + state_group = event.state_group if not state_group: state_group = self._simple_insert_txn( -- cgit 1.4.1 From da1dda3e1d9d3272527d35c23162c4baf7339d74 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 Oct 2014 11:18:04 +0000 Subject: Add transaction level logging and timing information. Add a _simple_delete method --- synapse/storage/__init__.py | 3 +- synapse/storage/_base.py | 74 ++++++++++++++++++++++++++++++++--------- synapse/storage/directory.py | 1 + synapse/storage/pdu.py | 13 +++++++- synapse/storage/registration.py | 7 ++-- synapse/storage/room.py | 2 ++ synapse/storage/state.py | 1 + synapse/storage/stream.py | 5 ++- synapse/storage/transactions.py | 6 ++++ 9 files changed, 91 insertions(+), 21 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 15a72d0cd7..a50e19349a 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -109,6 +109,7 @@ class DataStore(RoomMemberStore, RoomStore, try: yield self.runInteraction( + "persist_event", self._persist_pdu_event_txn, pdu=pdu, event=event, @@ -394,7 +395,7 @@ class DataStore(RoomMemberStore, RoomStore, prev_state_pdu=prev_state_pdu, ) - return self.runInteraction(_snapshot) + return self.runInteraction("snapshot_room", _snapshot) class Snapshot(object): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index d3e8741889..1192216971 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -29,15 +29,17 @@ import time logger = logging.getLogger(__name__) sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging to the .execute() method.""" - __slots__ = ["txn"] + __slots__ = ["txn", "name"] - def __init__(self, txn): + def __init__(self, txn, name): object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) def __getattr__(self, name): return getattr(self.txn, name) @@ -47,12 +49,15 @@ class LoggingTransaction(object): def execute(self, sql, *args, **kwargs): # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] %s", sql) + sql_logger.debug("[SQL] {%s} %s", self.name, sql) try: if args and args[0]: values = args[0] - sql_logger.debug("[SQL values] " + - ", ".join(("<%s>",) * len(values)), *values) + sql_logger.debug( + "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), + self.name, + *values + ) except: # Don't let logging failures stop SQL from working pass @@ -64,10 +69,11 @@ class LoggingTransaction(object): ) finally: end = time.clock() * 1000 - sql_logger.debug("[SQL time] %f", end - start) + sql_logger.debug("[SQL time] {%s} %f", self.name, end - start) class SQLBaseStore(object): + _TXN_ID = 0 def __init__(self, hs): self.hs = hs @@ -75,10 +81,24 @@ class SQLBaseStore(object): self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() - def runInteraction(self, func, *args, **kwargs): + def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" def inner_func(txn, *args, **kwargs): - return func(LoggingTransaction(txn), *args, **kwargs) + start = time.clock() * 1000 + txn_id = str(SQLBaseStore._TXN_ID) + SQLBaseStore._TXN_ID += 1 + + name = "%s-%s" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + try: + return func(LoggingTransaction(txn, name), *args, **kwargs) + finally: + end = time.clock() * 1000 + transaction_logger.debug( + "[TXN END] {%s} %f", + name, end - start + ) return self._db_pool.runInteraction(inner_func, *args, **kwargs) @@ -114,7 +134,7 @@ class SQLBaseStore(object): else: return cursor.fetchall() - return self.runInteraction(interaction) + return self.runInteraction("_execute", interaction) def _execute_and_decode(self, query, *args): return self._execute(self.cursor_to_dict, query, *args) @@ -131,6 +151,7 @@ class SQLBaseStore(object): or_replace : bool; if True performs an INSERT OR REPLACE """ return self.runInteraction( + "_simple_insert", self._simple_insert_txn, table, values, or_replace=or_replace, or_ignore=or_ignore, ) @@ -168,6 +189,7 @@ class SQLBaseStore(object): statement returns no rows """ return self._simple_selectupdate_one( + "_simple_select_one", table, keyvalues, retcols=retcols, allow_none=allow_none ) @@ -217,7 +239,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return txn.fetchall() - res = yield self.runInteraction(func) + res = yield self.runInteraction("_simple_select_onecol", func) defer.returnValue([r[0] for r in res]) @@ -240,7 +262,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn) - return self.runInteraction(func) + return self.runInteraction("_simple_select_list", func) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -308,7 +330,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched") return ret - return self.runInteraction(func) + return self.runInteraction("_simple_selectupdate_one", func) def _simple_delete_one(self, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a @@ -320,7 +342,7 @@ class SQLBaseStore(object): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) def func(txn): @@ -329,7 +351,25 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") if txn.rowcount > 1: raise StoreError(500, "more than one row matched") - return self.runInteraction(func) + return self.runInteraction("_simple_delete_one", func) + + def _simple_delete(self, table, keyvalues): + """Executes a DELETE query on the named table. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + + return self.runInteraction("_simple_delete", self._simple_delete_txn) + + def _simple_delete_txn(self, txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k, ) for k in keyvalues) + ) + + return txn.execute(sql, keyvalues.values()) def _simple_max_id(self, table): """Executes a SELECT query on the named table, expecting to return the @@ -347,7 +387,7 @@ class SQLBaseStore(object): return 0 return max_id - return self.runInteraction(func) + return self.runInteraction("_simple_max_id", func) def _parse_event_from_row(self, row_dict): d = copy.deepcopy({k: v for k, v in row_dict.items()}) @@ -371,7 +411,9 @@ class SQLBaseStore(object): ) def _parse_events(self, rows): - return self.runInteraction(self._parse_events_txn, rows) + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) def _parse_events_txn(self, txn, rows): events = [self._parse_event_from_row(r) for r in rows] diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 52373a28a6..d6a7113b9c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore): def delete_room_alias(self, room_alias): return self.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias, ) diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 9bdc831fd8..4a4341907b 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -47,7 +47,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin + "get_pdu", self._get_pdu_tuple, pdu_id, origin ) def _get_pdu_tuple(self, txn, pdu_id, origin): @@ -108,6 +108,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "get_current_state_for_context", self._get_current_state_for_context, context ) @@ -156,6 +157,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "mark_pdu_as_processed", self._mark_as_processed, pdu_id, pdu_origin ) @@ -165,6 +167,7 @@ class PduStore(SQLBaseStore): def get_all_pdus_from_context(self, context): """Get a list of all PDUs for a given context.""" return self.runInteraction( + "get_all_pdus_from_context", self._get_all_pdus_from_context, context, ) @@ -192,6 +195,7 @@ class PduStore(SQLBaseStore): list: A list of PduTuples """ return self.runInteraction( + "get_backfill", self._get_backfill, context, pdu_list, limit ) @@ -253,6 +257,7 @@ class PduStore(SQLBaseStore): context (str) """ return self.runInteraction( + "get_min_depth_for_context", self._get_min_depth_for_context, context ) @@ -291,6 +296,7 @@ class PduStore(SQLBaseStore): def get_latest_pdus_in_context(self, context): return self.runInteraction( + "get_latest_pdus_in_context", self._get_latest_pdus_in_context, context ) @@ -370,6 +376,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "is_pdu_new", self._is_pdu_new, pdu_id=pdu_id, origin=origin, @@ -523,6 +530,7 @@ class StatePduStore(SQLBaseStore): def get_unresolved_state_tree(self, new_state_pdu): return self.runInteraction( + "get_unresolved_state_tree", self._get_unresolved_state_tree, new_state_pdu ) @@ -562,6 +570,7 @@ class StatePduStore(SQLBaseStore): def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): return self.runInteraction( + "update_current_state", self._update_current_state, pdu_id, origin, context, pdu_type, state_key ) @@ -601,6 +610,7 @@ class StatePduStore(SQLBaseStore): """ return self.runInteraction( + "get_current_state_pdu", self._get_current_state_pdu, context, pdu_type, state_key ) @@ -660,6 +670,7 @@ class StatePduStore(SQLBaseStore): bool: True if the new_pdu clobbered the current state, False if not """ return self.runInteraction( + "handle_new_state", self._handle_new_state, new_pdu ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 719806f82b..a2ca6f9a69 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction(self._register, user_id, token, - password_hash) + yield self.runInteraction( + "register", + self._register, user_id, token, password_hash + ) def _register(self, txn, user_id, token, password_hash): now = int(self.clock.time()) @@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( + "get_user_by_token", self._query_for_auth, token ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8cd46334cf..7e48ce9cc3 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -150,6 +150,7 @@ class RoomStore(SQLBaseStore): def get_power_level(self, room_id, user_id): return self.runInteraction( + "get_power_level", self._get_power_level, room_id, user_id, ) @@ -183,6 +184,7 @@ class RoomStore(SQLBaseStore): def get_ops_levels(self, room_id): return self.runInteraction( + "get_ops_levels", self._get_ops_levels, room_id, ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0aa979c9f0..e08acd6404 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -59,6 +59,7 @@ class StateStore(SQLBaseStore): def store_state_groups(self, event): return self.runInteraction( + "store_state_groups", self._store_state_groups_txn, event ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d61f909939..8f7f61d29d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) def get_room_events_max_id(self): - return self.runInteraction(self._get_room_events_max_id_txn) + return self.runInteraction( + "get_room_events_max_id", + self._get_room_events_max_id_txn + ) def _get_room_events_max_id_txn(self, txn): txn.execute( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 2ba8e30efe..908014d38b 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -42,6 +42,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "get_received_txn_response", self._get_received_txn_response, transaction_id, origin ) @@ -73,6 +74,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "set_received_txn_response", self._set_received_txn_response, transaction_id, origin, code, response_dict ) @@ -106,6 +108,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "prep_send_transaction", self._prep_send_transaction, transaction_id, destination, origin_server_ts, pdu_list ) @@ -161,6 +164,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ return self.runInteraction( + "delivered_txn", self._delivered_txn, transaction_id, destination, code, response_dict ) @@ -186,6 +190,7 @@ class TransactionStore(SQLBaseStore): list: A list of `ReceivedTransactionsTable.EntryType` """ return self.runInteraction( + "get_transactions_after", self._get_transactions_after, transaction_id, destination ) @@ -216,6 +221,7 @@ class TransactionStore(SQLBaseStore): list: A list of PduTuple """ return self.runInteraction( + "get_pdus_after_transaction", self._get_pdus_after_transaction, transaction_id, destination ) -- cgit 1.4.1 From 5ff0bfb81dfbf27d5c889113d08edb609a2d145a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Nov 2014 14:16:41 +0000 Subject: Fix bug where we /always/ created a new state group --- synapse/handlers/federation.py | 3 ++- synapse/state.py | 61 ++++++++++++++++++++++++++---------------- synapse/storage/state.py | 9 ++----- 3 files changed, 42 insertions(+), 31 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e909af6bd8..c2cd91bb39 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -448,8 +448,9 @@ class FederationHandler(BaseHandler): ) if state_groups: + _, state = state_groups.items().pop() results = { - (e.type, e.state_key): e for e in state_groups[0].state + (e.type, e.state_key): e for e in state } event = yield self.store.get_event(event_id) diff --git a/synapse/state.py b/synapse/state.py index e2fd48bdae..a58acb3c1e 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -42,9 +42,6 @@ class StateHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - self._replication = hs.get_replication_layer() - self.server_name = hs.hostname - self.hs = hs @defer.inlineCallbacks @log_function @@ -71,9 +68,10 @@ class StateHandler(object): defer.returnValue(False) return - new_state = yield self.resolve_state_groups( - [e for e, _ in event.prev_events] - ) + ids = [e for e, _ in event.prev_events] + + ret = yield self.resolve_state_groups(ids) + state_group, new_state = ret event.old_state_events = copy.deepcopy(new_state) @@ -82,6 +80,10 @@ class StateHandler(object): if key in new_state: event.replaces_state = new_state[key].event_id new_state[key] = event + elif state_group: + event.state_group = state_group + event.state_events = new_state + defer.returnValue(False) event.state_group = None event.state_events = new_state @@ -100,7 +102,7 @@ class StateHandler(object): res = yield self.resolve_state_groups(event_ids) if event_type: - defer.returnValue(res.get((event_type, state_key))) + defer.returnValue(res[1].get((event_type, state_key))) return defer.returnValue(res.values()) @@ -112,9 +114,18 @@ class StateHandler(object): event_ids ) + group_names = set(state_groups.keys()) + if len(group_names) == 1: + name, state_list = state_groups.items().pop() + state = { + (e.type, e.state_key): e + for e in state_list + } + defer.returnValue((name, state)) + state = {} - for group in state_groups: - for s in group.state: + for group, g_state in state_groups.items(): + for s in g_state: state.setdefault( (s.type, s.state_key), {} @@ -135,25 +146,29 @@ class StateHandler(object): new_state = {} new_state.update(unconflicted_state) for key, events in conflicted_state.items(): - new_state[key] = yield self._resolve_state_events(events) + new_state[key] = self._resolve_state_events(events) except: logger.exception("Failed to resolve state") raise - defer.returnValue(new_state) + defer.returnValue((None, new_state)) def _get_power_level_from_event_state(self, event, user_id): - key = (RoomPowerLevelsEvent.TYPE, "", ) - power_level_event = event.old_state_events.get(key) - level = None - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level = power_level_event.content.get("users_default", 0) - - return level + if hasattr(event, "old_state_events") and event.old_state_events: + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + level = None + if power_level_event: + level = power_level_event.content.get("users", {}).get( + user_id + ) + if not level: + level = power_level_event.content.get("users_default", 0) + + return level + else: + return 0 - @defer.inlineCallbacks @log_function def _resolve_state_events(self, events): curr_events = events @@ -177,10 +192,10 @@ class StateHandler(object): if not curr_events: raise RuntimeError("Max didn't get a max?") elif len(curr_events) == 1: - defer.returnValue(curr_events[0]) + return curr_events[0] # TODO: For now, just choose the one with the largest event_id. - defer.returnValue( + return ( sorted( curr_events, key=lambda e: hashlib.sha1( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e08acd6404..68975969f5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -16,11 +16,6 @@ from ._base import SQLBaseStore from twisted.internet import defer -from collections import namedtuple - - -StateGroup = namedtuple("StateGroup", ("group", "state")) - class StateStore(SQLBaseStore): @@ -37,7 +32,7 @@ class StateStore(SQLBaseStore): if group: groups.add(group) - res = [] + res = {} for group in groups: state_ids = yield self._simple_select_onecol( table="state_groups_state", @@ -53,7 +48,7 @@ class StateStore(SQLBaseStore): if s: state.append(s) - res.append(StateGroup(group, state)) + res[group] = state defer.returnValue(res) -- cgit 1.4.1