diff options
-rw-r--r-- | synapse/api/auth.py | 33 | ||||
-rw-r--r-- | synapse/events/__init__.py | 35 | ||||
-rw-r--r-- | synapse/events/utils.py | 16 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 164 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 19 | ||||
-rw-r--r-- | synapse/server.py | 2 | ||||
-rw-r--r-- | synapse/state.py | 33 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 23 | ||||
-rw-r--r-- | synapse/storage/_base.py | 8 | ||||
-rw-r--r-- | synapse/storage/state.py | 13 |
10 files changed, 212 insertions, 134 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5261c3e3bf..3f2e58a5ef 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -351,27 +351,27 @@ class Auth(object): return self.store.is_server_admin(user) @defer.inlineCallbacks - def get_auth_events(self, event, current_state): - if event.type == RoomCreateEvent.TYPE: - event.auth_events = [] + def add_auth_events(self, builder, context): + if builder.type == RoomCreateEvent.TYPE: + builder.auth_events = [] return auth_events = [] key = (RoomPowerLevelsEvent.TYPE, "", ) - power_level_event = current_state.get(key) + power_level_event = context.current_state.get(key) if power_level_event: auth_events.append(power_level_event.event_id) key = (RoomJoinRulesEvent.TYPE, "", ) - join_rule_event = current_state.get(key) + join_rule_event = context.current_state.get(key) - key = (RoomMemberEvent.TYPE, event.user_id, ) - member_event = current_state.get(key) + key = (RoomMemberEvent.TYPE, builder.user_id, ) + member_event = context.current_state.get(key) key = (RoomCreateEvent.TYPE, "", ) - create_event = current_state.get(key) + create_event = context.current_state.get(key) if create_event: auth_events.append(create_event.event_id) @@ -381,8 +381,8 @@ class Auth(object): else: is_public = False - if event.type == RoomMemberEvent.TYPE: - e_type = event.content["membership"] + if builder.type == RoomMemberEvent.TYPE: + e_type = builder.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: if join_rule_event: auth_events.append(join_rule_event.event_id) @@ -393,11 +393,18 @@ class Auth(object): if member_event.content["membership"] == Membership.JOIN: auth_events.append(member_event.event_id) - auth_events = yield self.store.add_event_hashes( - auth_events + auth_ids = [(a.event_id, h) for a, h in auth_events] + auth_events_entries = yield self.store.add_event_hashes( + auth_ids ) - defer.returnValue(auth_events) + builder.auth_events = auth_events_entries + + context.auth_events = { + k: v + for k, v in context.current_state.items() + if v.event_id in auth_ids + } @log_function def _can_send_event(self, event, auth_events): diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 58edf2bc8f..e81b995d39 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -17,8 +17,8 @@ from frozendict import frozendict def _freeze(o): - if isinstance(o, dict): - return frozendict({k: _freeze(v) for k,v in o.items()}) + if isinstance(o, dict) or isinstance(o, frozendict): + return frozendict({k: _freeze(v) for k, v in o.items()}) if isinstance(o, basestring): return o @@ -31,6 +31,21 @@ def _freeze(o): return o +def _unfreeze(o): + if isinstance(o, frozendict) or isinstance(o, dict): + return dict({k: _unfreeze(v) for k, v in o.items()}) + + if isinstance(o, basestring): + return o + + try: + return [_unfreeze(i) for i in o] + except TypeError: + pass + + return o + + class _EventInternalMetadata(object): def __init__(self, internal_metadata_dict): self.__dict__ = internal_metadata_dict @@ -69,6 +84,7 @@ class EventBase(object): ) auth_events = _event_dict_property("auth_events") + depth = _event_dict_property("depth") content = _event_dict_property("content") event_id = _event_dict_property("event_id") hashes = _event_dict_property("hashes") @@ -81,6 +97,10 @@ class EventBase(object): type = _event_dict_property("type") user_id = _event_dict_property("sender") + @property + def membership(self): + return self.content["membership"] + def is_state(self): return hasattr(self, "state_key") @@ -134,3 +154,14 @@ class FrozenEvent(EventBase): e.internal_metadata = event.internal_metadata return e + + def get_dict(self): + # We need to unfreeze what we return + + d = _unfreeze(self._event_dict) + d.update({ + "signatures": self.signatures, + "unsigned": self.unsigned, + }) + + return d diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 412f690f08..1b05ee0a95 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.api.constants import EventTypes +from . import EventBase def prune_event(event): @@ -80,3 +81,18 @@ def prune_event(event): allowed_fields["content"] = new_content return type(event)(allowed_fields) + + +def serialize_event(hs, e): + # FIXME(erikj): To handle the case of presence events and the like + if not isinstance(e, EventBase): + return e + + # Should this strip out None's? + d = {k: v for k, v in e.get_dict().items()} + if "age_ts" in d["unsigned"]: + now = int(hs.get_clock().time_msec()) + d["unsigned"]["age"] = now - d["unsigned"]["age_ts"] + del d["unsigned"]["age_ts"] + + return d \ No newline at end of file diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 4052d0e1e7..810ce138ff 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -62,6 +62,8 @@ class BaseHandler(object): @defer.inlineCallbacks def _create_new_client_event(self, builder): + context = EventContext() + latest_ret = yield self.store.get_latest_events_in_room( builder.room_id, ) @@ -69,34 +71,26 @@ class BaseHandler(object): depth = max([d for _, _, d in latest_ret]) prev_events = [(e, h) for e, h, _ in latest_ret] - state_handler = self.state_handler - if builder.is_state(): - ret = yield state_handler.resolve_state_groups( - [e for e, _ in prev_events], - event_type=builder.event_type, - state_key=builder.state_key, - ) + builder.prev_events = prev_events + builder.depth = depth - group, curr_state, prev_state = ret + state_handler = self.state_handler + ret = yield state_handler.annotate_context_with_state( + builder, + context, + ) + group, prev_state = ret + if builder.is_state(): prev_state = yield self.store.add_event_hashes( prev_state ) builder.prev_state = prev_state - else: - group, curr_state, _ = yield state_handler.resolve_state_groups( - [e for e, _ in prev_events], - ) builder.internal_metadata.state_group = group - builder.prev_events = prev_events - builder.depth = depth - - auth_events = yield self.auth.get_auth_events(builder, curr_state) - - builder.update_event_key("auth_events", auth_events) + yield self.auth.add_auth_events(builder, context) add_hashes_and_signatures( builder, self.server_name, self.signing_key @@ -104,18 +98,6 @@ class BaseHandler(object): event = builder.build() - auth_ids = zip(*auth_events)[0] - curr_auth_events = { - k: v - for k, v in curr_state.items() - if v.event_id in auth_ids - } - - context = EventContext( - current_state=curr_state, - auth_events=curr_auth_events, - ) - defer.returnValue( (event, context,) ) @@ -128,7 +110,7 @@ class BaseHandler(object): if not suppress_auth: self.auth.check(event, auth_events=context.auth_events) - yield self.store.persist_event(event) + yield self.store.persist_event(event, context=context) destinations = set(extra_destinations) for k, s in context.current_state.items(): @@ -152,63 +134,63 @@ class BaseHandler(object): destinations=destinations, ) - @defer.inlineCallbacks - def _on_new_room_event(self, event, snapshot, extra_destinations=[], - extra_users=[], suppress_auth=False, - do_invite_host=None): - yield run_on_reactor() - - snapshot.fill_out_prev_events(event) - - yield self.state_handler.annotate_event_with_state(event) - - yield self.auth.add_auth_events(event) - - logger.debug("Signing event...") - - add_hashes_and_signatures( - event, self.server_name, self.signing_key - ) - - logger.debug("Signed event.") - - if not suppress_auth: - logger.debug("Authing...") - self.auth.check(event, auth_events=event.old_state_events) - logger.debug("Authed") - else: - logger.debug("Suppressed auth.") - - if do_invite_host: - federation_handler = self.hs.get_handlers().federation_handler - invite_event = yield federation_handler.send_invite( - do_invite_host, - event - ) - - # FIXME: We need to check if the remote changed anything else - event.signatures = invite_event.signatures - - yield self.store.persist_event(event) - - destinations = set(extra_destinations) - # Send a PDU to all hosts who have joined the room. - - for k, s in event.state_events.items(): - try: - if k[0] == RoomMemberEvent.TYPE: - if s.content["membership"] == Membership.JOIN: - destinations.add( - self.hs.parse_userid(s.state_key).domain - ) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - event.destinations = list(destinations) - - yield self.notifier.on_new_room_event(event, extra_users=extra_users) - - federation_handler = self.hs.get_handlers().federation_handler - yield federation_handler.handle_new_event(event, snapshot) + # @defer.inlineCallbacks + # def _on_new_room_event(self, event, snapshot, extra_destinations=[], + # extra_users=[], suppress_auth=False, + # do_invite_host=None): + # yield run_on_reactor() + # + # snapshot.fill_out_prev_events(event) + # + # yield self.state_handler.annotate_event_with_state(event) + # + # yield self.auth.add_auth_events(event) + # + # logger.debug("Signing event...") + # + # add_hashes_and_signatures( + # event, self.server_name, self.signing_key + # ) + # + # logger.debug("Signed event.") + # + # if not suppress_auth: + # logger.debug("Authing...") + # self.auth.check(event, auth_events=event.old_state_events) + # logger.debug("Authed") + # else: + # logger.debug("Suppressed auth.") + # + # if do_invite_host: + # federation_handler = self.hs.get_handlers().federation_handler + # invite_event = yield federation_handler.send_invite( + # do_invite_host, + # event + # ) + # + # # FIXME: We need to check if the remote changed anything else + # event.signatures = invite_event.signatures + # + # yield self.store.persist_event(event) + # + # destinations = set(extra_destinations) + # # Send a PDU to all hosts who have joined the room. + # + # for k, s in event.state_events.items(): + # try: + # if k[0] == RoomMemberEvent.TYPE: + # if s.content["membership"] == Membership.JOIN: + # destinations.add( + # self.hs.parse_userid(s.state_key).domain + # ) + # except: + # logger.warn( + # "Failed to get destination from event %s", s.event_id + # ) + # + # event.destinations = list(destinations) + # + # yield self.notifier.on_new_room_event(event, extra_users=extra_users) + # + # federation_handler = self.hs.get_handlers().federation_handler + # yield federation_handler.handle_new_event(event, snapshot) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b4a28ea3cb..5264e3eafc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,7 +17,8 @@ from ._base import BaseHandler -from synapse.api.events.utils import prune_event +from synapse.events.snapshot import EventContext +from synapse.events.utils import prune_event from synapse.api.errors import ( AuthError, FederationError, SynapseError, StoreError, ) @@ -416,7 +417,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def on_make_join_request(self, context, user_id): + def on_make_join_request(self, room_id, user_id): """ We've received a /make_join/ request, so we create a partial join event for the room and return that. We don *not* persist or process it until the other server has signed it and sent it back. @@ -424,7 +425,7 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new({ "type": RoomMemberEvent.TYPE, "content": {"membership": Membership.JOIN}, - "room_id": context, + "room_id": room_id, "sender": user_id, "state_key": user_id, }) @@ -433,9 +434,7 @@ class FederationHandler(BaseHandler): builder=builder, ) - yield self.state_handler.annotate_event_with_state(event) - yield self.auth.add_auth_events(event) - self.auth.check(event, auth_events=event.old_state_events) + self.auth.check(event, auth_events=context.auth_events) pdu = event @@ -505,7 +504,9 @@ class FederationHandler(BaseHandler): """ event = pdu - event.outlier = True + context = EventContext() + + event.internal_metadata.outlier = True event.signatures.update( compute_event_signature( @@ -515,10 +516,11 @@ class FederationHandler(BaseHandler): ) ) - yield self.state_handler.annotate_event_with_state(event) + yield self.state_handler.annotate_context_with_state(event, context) yield self.store.persist_event( event, + context=context, backfilled=False, ) @@ -640,6 +642,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _handle_new_event(self, event, state=None, backfilled=False, current_state=None, fetch_missing=True): + context = EventContext() is_new_state = yield self.state_handler.annotate_event_with_state( event, old_state=state diff --git a/synapse/server.py b/synapse/server.py index 8bc27bbc3c..0d0f3af3f4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,7 +20,7 @@ # Imports required for the default HomeServer() implementation from synapse.federation import initialize_http_replication -from synapse.api.events import serialize_event +from synapse.events.utils import serialize_event from synapse.api.events.factory import EventFactory from synapse.api.events.validator import EventValidator from synapse.notifier import Notifier diff --git a/synapse/state.py b/synapse/state.py index 8a556a27f6..cbb4243fad 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -136,6 +136,39 @@ class StateHandler(object): defer.returnValue(res[1].values()) @defer.inlineCallbacks + def annotate_context_with_state(self, event, context): + if event.is_state(): + ret = yield self.resolve_state_groups( + [e for e, _ in event.prev_events], + event_type=event.event_type, + state_key=event.state_key, + ) + else: + ret = yield self.resolve_state_groups( + [e for e, _ in event.prev_events], + ) + + group, curr_state, prev_state = ret + + context.current_state = curr_state + + prev_state = yield self.store.add_event_hashes( + prev_state + ) + + if hasattr(event, "auth_events") and event.auth_events: + auth_ids = zip(*event.auth_events)[0] + context.auth_events = { + k: v + for k, v in context.current_state.items() + if v.event_id in auth_ids + } + + defer.returnValue( + (group, prev_state) + ) + + @defer.inlineCallbacks @log_function def resolve_state_groups(self, event_ids, event_type=None, state_key=""): """ Given a list of event_ids this method fetches the state at each diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 205d125642..f172c2690a 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -21,6 +21,7 @@ from synapse.api.events.room import ( ) from synapse.util.logutils import log_function +from synapse.util.frozenutils import FrozenEncoder from .directory import DirectoryStore from .feedback import FeedbackStore @@ -93,8 +94,8 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks @log_function - def persist_event(self, event, backfilled=False, is_new_state=True, - current_state=None): + def persist_event(self, event, context, backfilled=False, + is_new_state=True, current_state=None): stream_ordering = None if backfilled: if not self.min_token_deferred.called: @@ -107,6 +108,7 @@ class DataStore(RoomMemberStore, RoomStore, "persist_event", self._persist_event_txn, event=event, + context=context, backfilled=backfilled, stream_ordering=stream_ordering, is_new_state=is_new_state, @@ -138,8 +140,9 @@ class DataStore(RoomMemberStore, RoomStore, defer.returnValue(event[0]) @log_function - def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, - is_new_state=True, current_state=None): + def _persist_event_txn(self, txn, event, context, backfilled, + stream_ordering=None, is_new_state=True, + current_state=None): if event.type == RoomMemberEvent.TYPE: self._store_room_member_txn(txn, event) elif event.type == FeedbackEvent.TYPE: @@ -152,12 +155,12 @@ class DataStore(RoomMemberStore, RoomStore, self._store_redaction(txn, event) outlier = False - if hasattr(event, "outlier"): - outlier = event.outlier + if hasattr(event.internal_metadata, "outlier"): + outlier = event.internal_metadata.outlier event_dict = { k: v - for k, v in event.get_full_dict().items() + for k, v in event.get_dict().items() if k not in [ "redacted", "redacted_because", @@ -179,7 +182,7 @@ class DataStore(RoomMemberStore, RoomStore, "event_id": event.event_id, "type": event.type, "room_id": event.room_id, - "content": json.dumps(event.content), + "content": json.dumps(event.content, cls=FrozenEncoder), "processed": True, "outlier": outlier, "depth": event.depth, @@ -190,7 +193,7 @@ class DataStore(RoomMemberStore, RoomStore, unrec = { k: v - for k, v in event.get_full_dict().items() + for k, v in event.get_dict().items() if k not in vals.keys() and k not in [ "redacted", "redacted_because", @@ -225,7 +228,7 @@ class DataStore(RoomMemberStore, RoomStore, room_id=event.room_id, ) - self._store_state_groups_txn(txn, event) + self._store_state_groups_txn(txn, event, context) if current_state: txn.execute( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bb61c20150..c56c3a0b0f 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,7 +15,8 @@ import logging from synapse.api.errors import StoreError -from synapse.api.events.utils import prune_event +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from syutil.base64util import encode_base64 @@ -497,10 +498,7 @@ class SQLBaseStore(object): d = json.loads(js) - ev = self.event_factory.create_event( - etype=d["type"], - **d - ) + ev = FrozenEvent(d) if hasattr(ev, "redacted") and ev.redacted: # Get the redaction event. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e0f44b3e59..b8e721ad72 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -86,11 +86,16 @@ class StateStore(SQLBaseStore): self._store_state_groups_txn, event ) - def _store_state_groups_txn(self, txn, event): - if event.state_events is None: + def _store_state_groups_txn(self, txn, event, context): + if context.current_state_events is None: return - state_group = event.state_group + state_events = context.current_state_events + + if event.is_state(): + state_events[(event.type, event.state_key)] = event + + state_group = context.state_group if not state_group: state_group = self._simple_insert_txn( txn, @@ -102,7 +107,7 @@ class StateStore(SQLBaseStore): or_ignore=True, ) - for state in event.state_events.values(): + for state in context.state_events.values(): self._simple_insert_txn( txn, table="state_groups_state", |