diff options
-rw-r--r-- | synapse/api/auth.py | 18 | ||||
-rwxr-xr-x | synapse/app/homeserver.py | 4 | ||||
-rw-r--r-- | synapse/federation/handler.py | 157 | ||||
-rw-r--r-- | synapse/federation/persistence.py | 64 | ||||
-rw-r--r-- | synapse/federation/replication.py | 6 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 24 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 99 | ||||
-rw-r--r-- | synapse/handlers/room.py | 123 | ||||
-rw-r--r-- | synapse/server.py | 5 | ||||
-rw-r--r-- | synapse/state.py | 18 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 170 | ||||
-rw-r--r-- | synapse/storage/_base.py | 12 | ||||
-rw-r--r-- | synapse/storage/feedback.py | 4 | ||||
-rw-r--r-- | synapse/storage/pdu.py | 29 | ||||
-rw-r--r-- | synapse/storage/room.py | 10 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 28 | ||||
-rw-r--r-- | synapse/storage/stream.py | 13 | ||||
-rw-r--r-- | tests/federation/test_federation.py | 2 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 43 | ||||
-rw-r--r-- | tests/handlers/test_room.py | 51 | ||||
-rw-r--r-- | tests/rest/test_events.py | 5 | ||||
-rw-r--r-- | tests/rest/test_rooms.py | 12 | ||||
-rw-r--r-- | tests/test_state.py | 17 | ||||
-rw-r--r-- | tests/utils.py | 9 |
24 files changed, 467 insertions, 456 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 886e132e10..2473a2b2bb 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,7 +33,7 @@ class Auth(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def check(self, event, raises=False): + def check(self, event, snapshot, raises=False): """ Checks if this event is correctly authed. Returns: @@ -48,7 +48,11 @@ class Auth(object): allowed = yield self.is_membership_change_allowed(event) defer.returnValue(allowed) else: - yield self.check_joined_room(event.room_id, event.user_id) + self._check_joined_room( + member=snapshot.membership_state, + user_id=snapshot.user_id, + room_id=snapshot.room_id, + ) defer.returnValue(True) else: raise AuthError(500, "Unknown event: %s" % event) @@ -66,14 +70,18 @@ class Auth(object): room_id=room_id, user_id=user_id ) - if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s" % - (user_id, room_id)) + self._check_joined_room(member, user_id, room_id) defer.returnValue(member) except AttributeError: pass defer.returnValue(None) + def _check_joined_room(self, member, user_id, room_id): + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s (%s)" % ( + user_id, room_id, repr(member) + )) + @defer.inlineCallbacks def is_membership_change_allowed(self, event): target_user_id = event.state_key diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index a89770ed7b..6d292ccf9a 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -296,10 +296,6 @@ def setup(): db_name=db_name, ) - # This object doesn't need to be saved because it's set as the handler for - # the replication layer - hs.get_federation() - hs.register_servlets() hs.create_resource_tree( diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py deleted file mode 100644 index 984c1558e9..0000000000 --- a/synapse/federation/handler.py +++ /dev/null @@ -1,157 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 matrix.org -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from .pdu_codec import PduCodec - -from synapse.api.errors import AuthError -from synapse.util.logutils import log_function - -import logging - - -logger = logging.getLogger(__name__) - - -class FederationEventHandler(object): - """ Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the home server (including auth and state conflict resoultion) - b) converting events that were produced by local clients that may need - to be sent to remote home servers. - """ - - def __init__(self, hs): - self.store = hs.get_datastore() - self.replication_layer = hs.get_replication_layer() - self.state_handler = hs.get_state_handler() - # self.auth_handler = gs.get_auth_handler() - self.event_handler = hs.get_handlers().federation_handler - self.server_name = hs.hostname - - self.lock_manager = hs.get_room_lock_manager() - - self.replication_layer.set_handler(self) - - self.pdu_codec = PduCodec(hs) - - @log_function - @defer.inlineCallbacks - def handle_new_event(self, event): - """ Takes in an event from the client to server side, that has already - been authed and handled by the state module, and sends it to any - remote home servers that may be interested. - - Args: - event - - Returns: - Deferred: Resolved when it has successfully been queued for - processing. - """ - yield self.fill_out_prev_events(event) - - pdu = self.pdu_codec.pdu_from_event(event) - - if not hasattr(pdu, "destinations") or not pdu.destinations: - pdu.destinations = [] - - yield self.replication_layer.send_pdu(pdu) - - @log_function - @defer.inlineCallbacks - def backfill(self, dest, room_id, limit): - pdus = yield self.replication_layer.backfill(dest, room_id, limit) - - if not pdus: - defer.returnValue([]) - - events = [ - self.pdu_codec.event_from_pdu(pdu) - for pdu in pdus - ] - - defer.returnValue(events) - - @log_function - def get_state_for_room(self, destination, room_id): - return self.replication_layer.get_state_for_context( - destination, room_id - ) - - @log_function - @defer.inlineCallbacks - def on_receive_pdu(self, pdu, backfilled): - """ Called by the ReplicationLayer when we have a new pdu. We need to - do auth checks and put it throught the StateHandler. - """ - event = self.pdu_codec.event_from_pdu(pdu) - - try: - with (yield self.lock_manager.lock(pdu.context)): - if event.is_state and not backfilled: - is_new_state = yield self.state_handler.handle_new_state( - pdu - ) - if not is_new_state: - return - else: - is_new_state = False - - yield self.event_handler.on_receive(event, is_new_state, backfilled) - - except AuthError: - # TODO: Implement something in federation that allows us to - # respond to PDU. - raise - - return - - @defer.inlineCallbacks - def _on_new_state(self, pdu, new_state_event): - # TODO: Do any store stuff here. Notifiy C2S about this new - # state. - - yield self.store.update_current_state( - pdu_id=pdu.pdu_id, - origin=pdu.origin, - context=pdu.context, - pdu_type=pdu.pdu_type, - state_key=pdu.state_key - ) - - yield self.event_handler.on_receive(new_state_event) - - @defer.inlineCallbacks - def fill_out_prev_events(self, event): - if hasattr(event, "prev_events"): - return - - results = yield self.store.get_latest_pdus_in_context( - event.room_id - ) - - es = [ - "%s@%s" % (p_id, origin) for p_id, origin, _ in results - ] - - event.prev_events = [e for e in es if e != event.event_id] - - if results: - event.depth = max([int(v) for _, _, v in results]) + 1 - else: - event.depth = 0 diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index e0e4de4e8c..4cf72b2e42 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -25,7 +25,6 @@ from .units import Pdu from synapse.util.logutils import log_function -import copy import json import logging @@ -41,28 +40,6 @@ class PduActions(object): self.store = datastore @log_function - def persist_received(self, pdu): - """ Persists the given `Pdu` that was received from a remote home - server. - - Returns: - Deferred - """ - return self._persist(pdu) - - @defer.inlineCallbacks - @log_function - def persist_outgoing(self, pdu): - """ Persists the given `Pdu` that this home server created. - - Returns: - Deferred - """ - ret = yield self._persist(pdu) - - defer.returnValue(ret) - - @log_function def mark_as_processed(self, pdu): """ Persist the fact that we have fully processed the given `Pdu` @@ -73,25 +50,6 @@ class PduActions(object): @defer.inlineCallbacks @log_function - def populate_previous_pdus(self, pdu): - """ Given an outgoing `Pdu` fill out its `prev_ids` key with the `Pdu`s - that we have received. - - Returns: - Deferred - """ - results = yield self.store.get_latest_pdus_in_context(pdu.context) - - pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in results] - - vs = [int(v) for _, _, v in results] - if vs: - pdu.depth = max(vs) + 1 - else: - pdu.depth = 0 - - @defer.inlineCallbacks - @log_function def after_transaction(self, transaction_id, destination, origin): """ Returns all `Pdu`s that we sent to the given remote home server after a given transaction id. @@ -143,28 +101,6 @@ class PduActions(object): depth=pdu.depth ) - @defer.inlineCallbacks - @log_function - def _persist(self, pdu): - kwargs = copy.copy(pdu.__dict__) - unrec_keys = copy.copy(pdu.unrecognized_keys) - del kwargs["content"] - kwargs["content_json"] = json.dumps(pdu.content) - kwargs["unrecognized_keys"] = json.dumps(unrec_keys) - - logger.debug("Persisting: %s", repr(kwargs)) - - if pdu.is_state: - ret = yield self.store.persist_state(**kwargs) - else: - ret = yield self.store.persist_pdu(**kwargs) - - yield self.store.update_min_depth_for_context( - pdu.context, pdu.depth - ) - - defer.returnValue(ret) - class TransactionActions(object): """ Defines persistence actions that relate to handling Transactions. diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index cf634a64b2..38ae360bcd 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -134,10 +134,8 @@ class ReplicationLayer(object): logger.debug("[%s] Persisting PDU", pdu.pdu_id) - #yield self.pdu_actions.populate_previous_pdus(pdu) - # Save *before* trying to send - yield self.pdu_actions.persist_outgoing(pdu) + yield self.store.persist_event(pdu=pdu) logger.debug("[%s] Persisted PDU", pdu.pdu_id) logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id) @@ -450,7 +448,7 @@ class ReplicationLayer(object): logger.exception("Failed to get PDU") # Persist the Pdu, but don't mark it as processed yet. - yield self.pdu_actions.persist_received(pdu) + yield self.store.persist_event(pdu=pdu) if not backfilled: ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 3f07b5aa4a..f141e92ce2 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from twisted.internet import defer class BaseHandler(object): @@ -26,3 +26,25 @@ class BaseHandler(object): self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() self.hs = hs + + +class BaseRoomHandler(BaseHandler): + + @defer.inlineCallbacks + def _on_new_room_event(self, event, snapshot, extra_destinations=[], + extra_users=[]): + snapshot.fill_out_prev_events(event) + + store_id = yield self.store.persist_event(event) + + destinations = set(extra_destinations) + # Send a PDU to all hosts who have joined the room. + destinations.update((yield self.store.get_joined_hosts_for_room( + event.room_id + ))) + event.destinations = list(destinations) + + self.notifier.on_new_room_event(event, extra_users=extra_users) + + federation_handler = self.hs.get_handlers().federation_handler + yield federation_handler.handle_new_event(event, snapshot) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 62edd5dbdc..9023c3d403 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,6 +20,9 @@ from ._base import BaseHandler from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent from synapse.api.constants import Membership from synapse.util.logutils import log_function +from synapse.federation.pdu_codec import PduCodec + +from synapse.api.errors import AuthError from twisted.internet import defer @@ -30,8 +33,14 @@ logger = logging.getLogger(__name__) class FederationHandler(BaseHandler): + """Handles events that originated from federation. + Responsible for: + a) handling received Pdus before handing them on as Events to the rest + of the home server (including auth and state conflict resoultion) + b) converting events that were produced by local clients that may need + to be sent to remote home servers. + """ - """Handles events that originated from federation.""" def __init__(self, hs): super(FederationHandler, self).__init__(hs) @@ -42,9 +51,67 @@ class FederationHandler(BaseHandler): self.waiting_for_join_list = {} + self.store = hs.get_datastore() + self.replication_layer = hs.get_replication_layer() + self.state_handler = hs.get_state_handler() + # self.auth_handler = gs.get_auth_handler() + self.server_name = hs.hostname + + self.lock_manager = hs.get_room_lock_manager() + + self.replication_layer.set_handler(self) + + self.pdu_codec = PduCodec(hs) + @log_function @defer.inlineCallbacks - def on_receive(self, event, is_new_state, backfilled): + def handle_new_event(self, event, snapshot): + """ Takes in an event from the client to server side, that has already + been authed and handled by the state module, and sends it to any + remote home servers that may be interested. + + Args: + event + snapshot (.storage.Snapshot): THe snapshot the event happened after + + Returns: + Deferred: Resolved when it has successfully been queued for + processing. + """ + + pdu = self.pdu_codec.pdu_from_event(event) + + if not hasattr(pdu, "destinations") or not pdu.destinations: + pdu.destinations = [] + + yield self.replication_layer.send_pdu(pdu) + + @log_function + def get_state_for_room(self, destination, room_id): + return self.replication_layer.get_state_for_context( + destination, room_id + ) + + @log_function + @defer.inlineCallbacks + def on_receive_pdu(self, pdu, backfilled): + """ Called by the ReplicationLayer when we have a new pdu. We need to + do auth checks and put it throught the StateHandler. + """ + event = self.pdu_codec.event_from_pdu(pdu) + + with (yield self.lock_manager.lock(pdu.context)): + if event.is_state and not backfilled: + is_new_state = yield self.state_handler.handle_new_state( + pdu + ) + if not is_new_state: + return + else: + is_new_state = False + # TODO: Implement something in federation that allows us to + # respond to PDU. + if hasattr(event, "state_key") and not is_new_state: logger.debug("Ignoring old state.") return @@ -86,8 +153,7 @@ class FederationHandler(BaseHandler): if not room: # Huh, let's try and get the current state try: - federation = self.hs.get_federation() - yield federation.get_state_for_room( + yield self.get_state_for_room( event.origin, event.room_id ) @@ -119,11 +185,10 @@ class FederationHandler(BaseHandler): "user_joined_room", user=user, room_id=event.room_id ) - @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit): - events = yield self.hs.get_federation().backfill(dest, room_id, limit) + events = yield self._backfill(dest, room_id, limit) for event in events: try: @@ -133,10 +198,23 @@ class FederationHandler(BaseHandler): defer.returnValue(events) + @defer.inlineCallbacks + def _backfill(self, dest, room_id, limit): + pdus = yield self.replication_layer.backfill(dest, room_id, limit) + + if not pdus: + defer.returnValue([]) + + events = [ + self.pdu_codec.event_from_pdu(pdu) + for pdu in pdus + ] + + defer.returnValue(events) + @log_function @defer.inlineCallbacks - def do_invite_join(self, target_host, room_id, joinee, content): - federation = self.hs.get_federation() + def do_invite_join(self, target_host, room_id, joinee, content, snapshot): hosts = yield self.store.get_joined_hosts_for_room(room_id) if self.hs.hostname in hosts: @@ -146,7 +224,7 @@ class FederationHandler(BaseHandler): # First get current state to see if we are already joined. try: - yield federation.get_state_for_room(target_host, room_id) + yield self.get_state_for_room(target_host, room_id) hosts = yield self.store.get_joined_hosts_for_room(room_id) if self.hs.hostname in hosts: @@ -166,7 +244,8 @@ class FederationHandler(BaseHandler): new_event.destinations = [target_host] - yield federation.handle_new_event(new_event) + snapshot.fill_out_prev_events(new_event) + yield self.handle_new_event(new_event, snapshot) # TODO (erikj): Time out here. d = defer.Deferred() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 760373344d..f01349b339 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -24,14 +24,14 @@ from synapse.api.events.room import ( ) from synapse.streams.config import PaginationConfig from synapse.util import stringutils -from ._base import BaseHandler +from ._base import BaseRoomHandler import logging logger = logging.getLogger(__name__) -class MessageHandler(BaseHandler): +class MessageHandler(BaseRoomHandler): def __init__(self, hs): super(MessageHandler, self).__init__(hs) @@ -83,20 +83,12 @@ class MessageHandler(BaseHandler): if stamp_event: event.content["hsob_ts"] = int(self.clock.time_msec()) - with (yield self.room_lock.lock(event.room_id)): - if not suppress_auth: - yield self.auth.check(event, raises=True) + snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - # store message in db - store_id = yield self.store.persist_event(event) + if not suppress_auth: + yield self.auth.check(event, snapshot, raises=True) - event.destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - - self.notifier.on_new_room_event(event) - - yield self.hs.get_federation().handle_new_event(event) + yield self._on_new_room_event(event, snapshot) @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, @@ -144,23 +136,16 @@ class MessageHandler(BaseHandler): SynapseError if something went wrong. """ - with (yield self.room_lock.lock(event.room_id)): - yield self.auth.check(event, raises=True) + snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - if stamp_event: - event.content["hsob_ts"] = int(self.clock.time_msec()) + yield self.auth.check(event, snapshot, raises=True) - yield self.state_handler.handle_new_event(event) - - # store in db - store_id = yield self.store.persist_event(event) + if stamp_event: + event.content["hsob_ts"] = int(self.clock.time_msec()) - event.destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - self.notifier.on_new_room_event(event) + yield self.state_handler.handle_new_event(event, snapshot) - yield self.hs.get_federation().handle_new_event(event) + yield self._on_new_room_event(event, snapshot) @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, @@ -229,18 +214,12 @@ class MessageHandler(BaseHandler): if stamp_event: event.content["hsob_ts"] = int(self.clock.time_msec()) - with (yield self.room_lock.lock(event.room_id)): - yield self.auth.check(event, raises=True) - - # store message in db - store_id = yield self.store.persist_event(event) + snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - event.destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - yield self.hs.get_federation().handle_new_event(event) + yield self.auth.check(event, snapshot, raises=True) - self.notifier.on_new_room_event(event) + # store message in db + yield self._on_new_room_event(event, snapshot) @defer.inlineCallbacks def snapshot_all_rooms(self, user_id=None, pagin_config=None, @@ -324,7 +303,7 @@ class MessageHandler(BaseHandler): defer.returnValue(ret) -class RoomCreationHandler(BaseHandler): +class RoomCreationHandler(BaseRoomHandler): @defer.inlineCallbacks def create_room(self, user_id, room_id, config): @@ -395,6 +374,13 @@ class RoomCreationHandler(BaseHandler): content=config, ) + snapshot = yield self.store.snapshot_room( + room_id=room_id, + user_id=user_id, + state_type=RoomConfigEvent.TYPE, + state_key="", + ) + if room_alias: yield self.store.create_room_alias_association( room_id=room_id, @@ -402,9 +388,11 @@ class RoomCreationHandler(BaseHandler): servers=[self.hs.hostname], ) - yield self.state_handler.handle_new_event(config_event) + yield self.state_handler.handle_new_event(config_event, snapshot) + # store_id = persist... - yield self.hs.get_federation().handle_new_event(config_event) + federation_handler = self.hs.get_handlers().federation_handler + yield federation_handler.handle_new_event(config_event, snapshot) content = {"membership": Membership.JOIN} join_event = self.event_factory.create_event( @@ -428,7 +416,7 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) -class RoomMemberHandler(BaseHandler): +class RoomMemberHandler(BaseRoomHandler): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns @@ -539,6 +527,11 @@ class RoomMemberHandler(BaseHandler): """ target_user_id = event.state_key + snapshot = yield self.store.snapshot_room( + event.room_id, event.user_id, + RoomMemberEvent.TYPE, target_user_id + ) + ## TODO(markjh): get prev state from snapshot. prev_state = yield self.store.get_room_member( target_user_id, event.room_id ) @@ -559,24 +552,22 @@ class RoomMemberHandler(BaseHandler): # if this HS is not currently in the room, i.e. we have to do the # invite/join dance. if event.membership == Membership.JOIN: - yield self._do_join(event, do_auth=do_auth) + yield self._do_join(event, snapshot, do_auth=do_auth) else: # This is not a JOIN, so we can handle it normally. if do_auth: - yield self.auth.check(event, raises=True) + yield self.auth.check(event, snapshot, raises=True) - prev_state = yield self.store.get_room_member( - target_user_id, event.room_id - ) if prev_state and prev_state.membership == event.membership: # double same action, treat this event as a NOOP. defer.returnValue({}) return - yield self.state_handler.handle_new_event(event) + yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], + snapshot=snapshot, ) defer.returnValue({"room_id": room_id}) @@ -606,12 +597,16 @@ class RoomMemberHandler(BaseHandler): content=content, ) - yield self._do_join(new_event, room_host=host, do_auth=True) + snapshot = yield self.store.snapshot_room( + room_id, joinee, RoomMemberEvent.TYPE, joinee + ) + + yield self._do_join(new_event, snapshot, room_host=host, do_auth=True) defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def _do_join(self, event, room_host=None, do_auth=True): + def _do_join(self, event, snapshot, room_host=None, do_auth=True): joinee = self.hs.parse_userid(event.state_key) # room_id = RoomID.from_string(event.room_id, self.hs) room_id = event.room_id @@ -633,6 +628,7 @@ class RoomMemberHandler(BaseHandler): elif room_host: should_do_dance = True else: + # TODO(markjh): get prev_state from snapshot prev_state = yield self.store.get_room_member( joinee.to_string(), room_id ) @@ -652,7 +648,7 @@ class RoomMemberHandler(BaseHandler): if should_do_dance: handler = self.hs.get_handlers().federation_handler have_joined = yield handler.do_invite_join( - room_host, room_id, event.user_id, event.content + room_host, room_id, event.user_id, event.content, snapshot ) # We want to do the _do_update inside the room lock. @@ -660,12 +656,13 @@ class RoomMemberHandler(BaseHandler): logger.debug("Doing normal join") if do_auth: - yield self.auth.check(event, raises=True) + yield self.auth.check(event, snapshot, raises=True) - yield self.state_handler.handle_new_event(event) + yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], + snapshot=snapshot, ) user = self.hs.parse_userid(event.user_id) @@ -709,15 +706,8 @@ class RoomMemberHandler(BaseHandler): defer.returnValue([r.room_id for r in rooms]) - @defer.inlineCallbacks - def _do_local_membership_update(self, event, membership): - # store membership - store_id = yield self.store.persist_event(event) - - # Send a PDU to all hosts who have joined the room. - destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) + def _do_local_membership_update(self, event, membership, snapshot): + destinations = [] # If we're inviting someone, then we should also send it to that # HS. @@ -732,13 +722,12 @@ class RoomMemberHandler(BaseHandler): host = target_user.domain destinations.append(host) - event.destinations = list(set(destinations)) - - yield self.hs.get_federation().handle_new_event(event) - self.notifier.on_new_room_event(event, extra_users=[target_user]) - + return self._on_new_room_event( + event, snapshot, extra_destinations=destinations, + extra_users=[target_user] + ) -class RoomListHandler(BaseHandler): +class RoomListHandler(BaseRoomHandler): @defer.inlineCallbacks def get_public_room_list(self): diff --git a/synapse/server.py b/synapse/server.py index c29c61220d..ade8dc6c15 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,7 +20,6 @@ # Imports required for the default HomeServer() implementation from synapse.federation import initialize_http_replication -from synapse.federation.handler import FederationEventHandler from synapse.api.events.factory import EventFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -59,7 +58,6 @@ class BaseHomeServer(object): 'http_client', 'db_pool', 'persistence_service', - 'federation', 'replication_layer', 'datastore', 'event_factory', @@ -167,9 +165,6 @@ class HomeServer(BaseHomeServer): def build_replication_layer(self): return initialize_http_replication(self) - def build_federation(self): - return FederationEventHandler(self) - def build_datastore(self): return DataStore(self) diff --git a/synapse/state.py b/synapse/state.py index ca8e1ca630..e1a1a159bb 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -45,7 +45,7 @@ class StateHandler(object): @defer.inlineCallbacks @log_function - def handle_new_event(self, event): + def handle_new_event(self, event, snapshot): """ Given an event this works out if a) we have sufficient power level to update the state and b) works out what the prev_state should be. @@ -70,25 +70,13 @@ class StateHandler(object): # Now I need to fill out the prev state and work out if it has auth # (w.r.t. to power levels) - results = yield self.store.get_latest_pdus_in_context( - event.room_id - ) + snapshot.fill_out_prev_events(event) event.prev_events = [ - encode_event_id(p_id, origin) for p_id, origin, _ in results - ] - event.prev_events = [ e for e in event.prev_events if e != event.event_id ] - if results: - event.depth = max([int(v) for _, _, v in results]) + 1 - else: - event.depth = 0 - - current_state = yield self.store.get_current_state_pdu( - key.context, key.type, key.state_key - ) + current_state = snapshot.prev_state_pdu if current_state: event.prev_state = encode_event_id( diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 38ab03c45c..e8faba3eeb 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -57,20 +57,22 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks @log_function - def persist_event(self, event, backfilled=False): - if event.type == RoomMemberEvent.TYPE: - yield self._store_room_member(event) - elif event.type == FeedbackEvent.TYPE: - yield self._store_feedback(event) -# elif event.type == RoomConfigEvent.TYPE: -# yield self._store_room_config(event) - elif event.type == RoomNameEvent.TYPE: - yield self._store_room_name(event) - elif event.type == RoomTopicEvent.TYPE: - yield self._store_room_topic(event) - - ret = yield self._store_event(event, backfilled) - defer.returnValue(ret) + def persist_event(self, event=None, backfilled=False, pdu=None): + stream_ordering = None + if backfilled: + if not self.min_token_deferred.called: + yield self.min_token_deferred + self.min_token -= 1 + stream_ordering = self.min_token + + latest = yield self._db_pool.runInteraction( + self._persist_pdu_event_txn, + pdu=pdu, + event=event, + backfilled=backfilled, + stream_ordering=stream_ordering, + ) + defer.returnValue(latest) @defer.inlineCallbacks def get_event(self, event_id): @@ -89,12 +91,44 @@ class DataStore(RoomMemberStore, RoomStore, event = self._parse_event_from_row(events_dict) defer.returnValue(event) - @defer.inlineCallbacks + def _persist_pdu_event_txn(self, txn, pdu=None, event=None, + backfilled=False, stream_ordering=None): + if pdu is not None: + self._persist_event_pdu_txn(txn, pdu) + if event is not None: + return self._persist_event_txn( + txn, event, backfilled, stream_ordering + ) + + def _persist_event_pdu_txn(self, txn, pdu): + cols = dict(pdu.__dict__) + unrec_keys = dict(pdu.unrecognized_keys) + del cols["content"] + del cols["prev_pdus"] + cols["content_json"] = json.dumps(pdu.content) + cols["unrecognized_keys"] = json.dumps(unrec_keys) + + logger.debug("Persisting: %s", repr(cols)) + + if pdu.is_state: + self._persist_state_txn(txn, pdu.prev_pdus, cols) + else: + self._persist_pdu_txn(txn, pdu.prev_pdus, cols) + + self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth) + @log_function - def _store_event(self, event, backfilled): - # FIXME (erikj): This should be removed when we start amalgamating - # event and pdu storage - yield self.hs.get_federation().fill_out_prev_events(event) + def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None): + if event.type == RoomMemberEvent.TYPE: + self._store_room_member_txn(txn, event) + elif event.type == FeedbackEvent.TYPE: + self._store_feedback_txn(txn,event) +# elif event.type == RoomConfigEvent.TYPE: +# self._store_room_config_txn(txn, event) + elif event.type == RoomNameEvent.TYPE: + self._store_room_name_txn(txn, event) + elif event.type == RoomTopicEvent.TYPE: + self._store_room_topic_txn(txn, event) vals = { "topological_ordering": event.depth, @@ -105,17 +139,14 @@ class DataStore(RoomMemberStore, RoomStore, "processed": True, } + if stream_ordering is not None: + vals["stream_ordering"] = stream_ordering + if hasattr(event, "outlier"): vals["outlier"] = event.outlier else: vals["outlier"] = False - if backfilled: - if not self.min_token_deferred.called: - yield self.min_token_deferred - self.min_token -= 1 - vals["stream_ordering"] = self.min_token - unrec = { k: v for k, v in event.get_full_dict().items() @@ -124,7 +155,7 @@ class DataStore(RoomMemberStore, RoomStore, vals["unrecognized_keys"] = json.dumps(unrec) try: - yield self._simple_insert("events", vals) + self._simple_insert_txn(txn, "events", vals) except: logger.exception( "Failed to persist, probably duplicate: %s", @@ -143,9 +174,10 @@ class DataStore(RoomMemberStore, RoomStore, if hasattr(event, "prev_state"): vals["prev_state"] = event.prev_state - yield self._simple_insert("state_events", vals) + self._simple_insert_txn(txn, "state_events", vals) - yield self._simple_insert( + self._simple_insert_txn( + txn, "current_state_events", { "event_id": event.event_id, @@ -155,8 +187,7 @@ class DataStore(RoomMemberStore, RoomStore, } ) - latest = yield self.get_room_events_max_id() - defer.returnValue(latest) + return self._get_room_events_max_id_txn(txn) @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -192,6 +223,85 @@ class DataStore(RoomMemberStore, RoomStore, defer.returnValue(self.min_token) + def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): + """Snapshot the room for an update by a user + Args: + room_id (synapse.types.RoomId): The room to snapshot. + user_id (synapse.types.UserId): The user to snapshot the room for. + state_type (str): Optional state type to snapshot. + state_key (str): Optional state key to snapshot. + Returns: + synapse.storage.Snapshot: A snapshot of the state of the room. + """ + def _snapshot(txn): + membership_state = self._get_room_member(txn, user_id, room_id) + prev_pdus = self._get_latest_pdus_in_context( + txn, room_id + ) + if state_type is not None and state_key is not None: + prev_state_pdu = self._get_current_state_pdu( + txn, room_id, state_type, state_key + ) + else: + prev_state_pdu = None + + return Snapshot( + store=self, + room_id=room_id, + user_id=user_id, + prev_pdus=prev_pdus, + membership_state=membership_state, + state_type=state_type, + state_key=state_key, + prev_state_pdu=prev_state_pdu, + ) + + return self._db_pool.runInteraction(_snapshot) + + +class Snapshot(object): + """Snapshot of the state of a room + Args: + store (DataStore): The datastore. + room_id (RoomId): The room of the snapshot. + user_id (UserId): The user this snapshot is for. + prev_pdus (list): The list of PDU ids this snapshot is after. + membership_state (RoomMemberEvent): The current state of the user in + the room. + state_type (str, optional): State type captured by the snapshot + state_key (str, optional): State key captured by the snapshot + prev_state_pdu (PduEntry, optional): pdu id of + the previous value of the state type and key in the room. + """ + + def __init__(self, store, room_id, user_id, prev_pdus, + membership_state, state_type=None, state_key=None, + prev_state_pdu=None): + self.store = store + self.room_id = room_id + self.user_id = user_id + self.prev_pdus = prev_pdus + self.membership_state = membership_state + self.state_type = state_type + self.state_key = state_key + self.prev_state_pdu = prev_state_pdu + + def fill_out_prev_events(self, event): + if hasattr(event, "prev_events"): + return + + es = [ + "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus + ] + + event.prev_events = [e for e in es if e != event.event_id] + + if self.prev_pdus: + event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1 + else: + event.depth = 0 + + def schema_path(schema): """ Get a filesystem path for the named database schema diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 75aab2d3b9..33d56f47ce 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -86,16 +86,18 @@ class SQLBaseStore(object): table : string giving the table name values : dict of new column names and values for them """ + return self._db_pool.runInteraction( + self._simple_insert_txn, table, values, + ) + + def _simple_insert_txn(self, txn, table, values): sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in values), ", ".join("?" for k in values) ) - - def func(txn): - txn.execute(sql, values.values()) - return txn.lastrowid - return self._db_pool.runInteraction(func) + txn.execute(sql, values.values()) + return txn.lastrowid def _simple_select_one(self, table, keyvalues, retcols, allow_none=False): diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py index 513b72d279..bac3dea955 100644 --- a/synapse/storage/feedback.py +++ b/synapse/storage/feedback.py @@ -20,8 +20,8 @@ from ._base import SQLBaseStore class FeedbackStore(SQLBaseStore): - def _store_feedback(self, event): - return self._simple_insert("feedback", { + def _store_feedback_txn(self, txn, event): + self._simple_insert_txn(txn, "feedback", { "event_id": event.event_id, "feedback_type": event.content["type"], "room_id": event.room_id, diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 7655f43ede..9fd44f2454 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -114,7 +114,7 @@ class PduStore(SQLBaseStore): return self._get_pdu_tuples(txn, res) - def persist_pdu(self, prev_pdus, **cols): + def _persist_pdu_txn(self, txn, prev_pdus, cols): """Inserts a (non-state) PDU into the database. Args: @@ -122,11 +122,6 @@ class PduStore(SQLBaseStore): prev_pdus (list) **cols: The columns to insert into the PdusTable. """ - return self._db_pool.runInteraction( - self._persist_pdu, prev_pdus, cols - ) - - def _persist_pdu(self, txn, prev_pdus, cols): entry = PdusTable.EntryType( **{k: cols.get(k, None) for k in PdusTable.fields} ) @@ -262,7 +257,7 @@ class PduStore(SQLBaseStore): return row[0] if row else None - def update_min_depth_for_context(self, context, depth): + def _update_min_depth_for_context_txn(self, txn, context, depth): """Update the minimum `depth` of the given context, which is the line on which we stop backfilling backwards. @@ -270,11 +265,6 @@ class PduStore(SQLBaseStore): context (str) depth (int) """ - return self._db_pool.runInteraction( - self._update_min_depth_for_context, context, depth - ) - - def _update_min_depth_for_context(self, txn, context, depth): min_depth = self._get_min_depth_interaction(txn, context) do_insert = depth < min_depth if min_depth else True @@ -286,7 +276,7 @@ class PduStore(SQLBaseStore): (context, depth) ) - def get_latest_pdus_in_context(self, context): + def _get_latest_pdus_in_context(self, txn, context): """Get's a list of the most current pdus for a given context. This is used when we are sending a Pdu and need to fill out the `prev_pdus` key @@ -295,11 +285,6 @@ class PduStore(SQLBaseStore): txn context """ - return self._db_pool.runInteraction( - self._get_latest_pdus_in_context, context - ) - - def _get_latest_pdus_in_context(self, txn, context): query = ( "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p " "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id " @@ -485,7 +470,7 @@ class StatePduStore(SQLBaseStore): """A collection of queries for handling state PDUs. """ - def persist_state(self, prev_pdus, **cols): + def _persist_state_txn(self, txn, prev_pdus, cols): """Inserts a state PDU into the database Args: @@ -493,12 +478,6 @@ class StatePduStore(SQLBaseStore): prev_pdus (list) **cols: The columns to insert into the PdusTable and StatePdusTable """ - - return self._db_pool.runInteraction( - self._persist_state, prev_pdus, cols - ) - - def _persist_state(self, txn, prev_pdus, cols): pdu_entry = PdusTable.EntryType( **{k: cols.get(k, None) for k in PdusTable.fields} ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index a5751005ef..d1f1a232f8 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -129,8 +129,9 @@ class RoomStore(SQLBaseStore): defer.returnValue(ret) - def _store_room_topic(self, event): - return self._simple_insert( + def _store_room_topic_txn(self, txn, event): + self._simple_insert_txn( + txn, "topics", { "event_id": event.event_id, @@ -139,8 +140,9 @@ class RoomStore(SQLBaseStore): } ) - def _store_room_name(self, event): - return self._simple_insert( + def _store_room_name_txn(self, txn, event): + self._simple_insert_txn( + txn, "room_names", { "event_id": event.event_id, diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4ad37af0f3..2746126e85 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,14 +26,14 @@ logger = logging.getLogger(__name__) class RoomMemberStore(SQLBaseStore): - @defer.inlineCallbacks - def _store_room_member(self, event): + def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ target_user_id = event.state_key domain = self.hs.parse_userid(target_user_id).domain - yield self._simple_insert( + self._simple_insert_txn( + txn, "room_memberships", { "event_id": event.event_id, @@ -50,13 +50,13 @@ class RoomMemberStore(SQLBaseStore): "INSERT OR IGNORE INTO room_hosts (room_id, host) " "VALUES (?, ?)" ) - yield self._execute(None, sql, event.room_id, domain) + txn.execute(sql, (event.room_id, domain)) else: sql = ( "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" ) - yield self._execute(None, sql, event.room_id, domain) + txn.execute(sql, (event.room_id, domain)) @defer.inlineCallbacks def get_room_member(self, user_id, room_id): @@ -75,6 +75,24 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(rows[0] if rows else None) + def _get_room_member(self, txn, user_id, room_id): + sql = ( + "SELECT e.* FROM events as e" + " INNER JOIN room_memberships as m" + " ON e.event_id = m.event_id" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id" + " WHERE m.user_id = ? and e.room_id = ?" + " LIMIT 1" + ) + txn.execute(sql, (user_id, room_id)) + rows = self.cursor_to_dict(txn) + if rows: + return self._parse_event_from_row(rows[0]) + else: + return None + + def get_room_members(self, room_id, membership=None): """Retrieve the current room member list for a room. diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 6a22d5aead..4f42afc015 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -281,17 +281,20 @@ class StreamStore(SQLBaseStore): ) ) - @defer.inlineCallbacks def get_room_events_max_id(self): - res = yield self._execute_and_decode( + return self._db_pool.runInteraction(self._get_room_events_max_id_txn) + + def _get_room_events_max_id_txn(self, txn): + txn.execute( "SELECT MAX(stream_ordering) as m FROM events" ) + res = self.cursor_to_dict(txn) + logger.debug("get_room_events_max_id: %s", res) if not res or not res[0] or not res[0]["m"]: - defer.returnValue("s1") - return + return "s1" key = res[0]["m"] - defer.returnValue("s%d" % (key,)) + return "s%d" % (key,) diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py index 58590e4fcd..938b57bec9 100644 --- a/tests/federation/test_federation.py +++ b/tests/federation/test_federation.py @@ -58,7 +58,7 @@ class FederationTestCase(unittest.TestCase): self.mock_persistence = Mock(spec=[ "get_current_state_for_context", "get_pdu", - "persist_pdu", + "persist_event", "update_min_depth_for_context", "prep_send_transaction", "delivered_txn", diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 7f7e2b63a0..bc260c8aab 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -22,8 +22,9 @@ from synapse.api.events.room import ( from synapse.api.constants import Membership from synapse.handlers.federation import FederationHandler from synapse.server import HomeServer +from synapse.federation.units import Pdu -from mock import NonCallableMock +from mock import NonCallableMock, ANY import logging @@ -60,36 +61,41 @@ class FederationTestCase(unittest.TestCase): @defer.inlineCallbacks def test_msg(self): - event = self.hs.get_event_factory().create_event( - etype=MessageEvent.TYPE, - msg_id="bob", - room_id="foo", + pdu = Pdu( + pdu_type=MessageEvent.TYPE, + context="foo", content={"msgtype": u"fooo"}, + ts=0, + pdu_id="a", + origin="b", ) store_id = "ASD" self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.get_room.return_value = defer.succeed(True) - yield self.handlers.federation_handler.on_receive(event, False, False) + yield self.handlers.federation_handler.on_receive_pdu(pdu, False) - self.datastore.persist_event.assert_called_once_with(event, False) - self.notifier.on_new_room_event.assert_called_once_with(event) + self.datastore.persist_event.assert_called_once_with(ANY, False) + self.notifier.on_new_room_event.assert_called_once_with(ANY) @defer.inlineCallbacks def test_invite_join_target_this(self): room_id = "foo" user_id = "@bob:red" - event = self.hs.get_event_factory().create_event( - etype=InviteJoinEvent.TYPE, + pdu = Pdu( + pdu_type=InviteJoinEvent.TYPE, user_id=user_id, target_host=self.hostname, - room_id=room_id, + context=room_id, content={}, + ts=0, + pdu_id="a", + origin="b", ) - yield self.handlers.federation_handler.on_receive(event, False, False) + yield self.handlers.federation_handler.on_receive_pdu(pdu, False) mem_handler = self.handlers.room_member_handler self.assertEquals(1, mem_handler.change_membership.call_count) @@ -106,15 +112,18 @@ class FederationTestCase(unittest.TestCase): room_id = "foo" user_id = "@bob:red" - event = self.hs.get_event_factory().create_event( - etype=InviteJoinEvent.TYPE, + pdu = Pdu( + pdu_type=InviteJoinEvent.TYPE, user_id=user_id, - target_user_id="@red:not%s" % self.hostname, - room_id=room_id, + state_key="@red:not%s" % self.hostname, + context=room_id, content={}, + ts=0, + pdu_id="a", + origin="b", ) - yield self.handlers.federation_handler.on_receive(event, False, False) + yield self.handlers.federation_handler.on_receive_pdu(pdu, False) mem_handler = self.handlers.room_member_handler self.assertEquals(0, mem_handler.change_membership.call_count) diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index 04d88cb199..09d2a92e16 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): "get_room_member", "get_room", "store_room", + "snapshot_room", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -52,29 +53,36 @@ class RoomMemberHandlerTestCase(unittest.TestCase): handlers=NonCallableMock(spec_set=[ "room_member_handler", "profile_handler", + "federation_handler", ]), auth=NonCallableMock(spec_set=["check"]), - federation=NonCallableMock(spec_set=[ - "handle_new_event", - "get_state_for_room", - ]), state_handler=NonCallableMock(spec_set=["handle_new_event"]), ) + self.federation = NonCallableMock(spec_set=[ + "handle_new_event", + "get_state_for_room", + ]) + self.datastore = hs.get_datastore() self.handlers = hs.get_handlers() self.notifier = hs.get_notifier() - self.federation = hs.get_federation() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() self.hs = hs + self.handlers.federation_handler = self.federation + self.distributor.declare("collect_presencelike_data") self.handlers.room_member_handler = RoomMemberHandler(self.hs) self.handlers.profile_handler = ProfileHandler(self.hs) self.room_member_handler = self.handlers.room_member_handler + self.snapshot = Mock() + self.datastore.snapshot_room.return_value = self.snapshot + + @defer.inlineCallbacks def test_invite(self): room_id = "!foo:red" @@ -104,8 +112,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase): # Actual invocation yield self.room_member_handler.change_membership(event) - self.state_handler.handle_new_event.assert_called_once_with(event) - self.federation.handle_new_event.assert_called_once_with(event) + self.state_handler.handle_new_event.assert_called_once_with( + event, self.snapshot, + ) + self.federation.handle_new_event.assert_called_once_with( + event, self.snapshot, + ) self.assertEquals( set(["blue", "red", "green"]), @@ -116,8 +128,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): event ) self.notifier.on_new_room_event.assert_called_once_with( - event, extra_users=[self.hs.parse_userid(target_user_id)]) - + event, extra_users=[self.hs.parse_userid(target_user_id)] + ) self.assertFalse(self.datastore.get_room.called) self.assertFalse(self.datastore.store_room.called) self.assertFalse(self.federation.get_state_for_room.called) @@ -148,6 +160,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.get_joined_hosts_for_room.side_effect = get_joined + store_id = "store_id_fooo" self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.get_room.return_value = defer.succeed(1) # Not None. @@ -163,8 +176,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase): # Actual invocation yield self.room_member_handler.change_membership(event) - self.state_handler.handle_new_event.assert_called_once_with(event) - self.federation.handle_new_event.assert_called_once_with(event) + self.state_handler.handle_new_event.assert_called_once_with( + event, self.snapshot + ) + self.federation.handle_new_event.assert_called_once_with( + event, self.snapshot + ) self.assertEquals( set(["red", "green"]), @@ -312,27 +329,31 @@ class RoomCreationTest(unittest.TestCase): db_pool=None, datastore=NonCallableMock(spec_set=[ "store_room", + "snapshot_room", ]), http_client=NonCallableMock(spec_set=[]), notifier=NonCallableMock(spec_set=["on_new_room_event"]), handlers=NonCallableMock(spec_set=[ "room_creation_handler", "room_member_handler", + "federation_handler", ]), auth=NonCallableMock(spec_set=["check"]), - federation=NonCallableMock(spec_set=[ - "handle_new_event", - ]), state_handler=NonCallableMock(spec_set=["handle_new_event"]), ) + self.federation = NonCallableMock(spec_set=[ + "handle_new_event", + ]) + self.datastore = hs.get_datastore() self.handlers = hs.get_handlers() self.notifier = hs.get_notifier() - self.federation = hs.get_federation() self.state_handler = hs.get_state_handler() self.hs = hs + self.handlers.federation_handler = self.federation + self.handlers.room_creation_handler = RoomCreationHandler(self.hs) self.room_creation_handler = self.handlers.room_creation_handler diff --git a/tests/rest/test_events.py b/tests/rest/test_events.py index 94ad8910e3..3099a24e8c 100644 --- a/tests/rest/test_events.py +++ b/tests/rest/test_events.py @@ -128,9 +128,9 @@ class EventStreamPermissionsTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), replication_layer=Mock(), state_handler=state_handler, + datastore=MemoryDataStore(), persistence_service=persistence_service, clock=Mock(spec=[ "call_later", @@ -139,9 +139,10 @@ class EventStreamPermissionsTestCase(RestTestCase): ]), ) + hs.get_handlers().federation_handler = Mock() + hs.get_clock().time_msec.return_value = 1000000 - hs.datastore = MemoryDataStore() synapse.rest.register.register_servlets(hs, self.mock_resource) synapse.rest.events.register_servlets(hs, self.mock_resource) synapse.rest.room.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/test_rooms.py b/tests/rest/test_rooms.py index 589b434446..914dc28f53 100644 --- a/tests/rest/test_rooms.py +++ b/tests/rest/test_rooms.py @@ -54,12 +54,12 @@ class RoomPermissionsTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), datastore=MemoryDataStore(), replication_layer=Mock(), state_handler=state_handler, persistence_service=persistence_service, ) + hs.get_handlers().federation_handler = Mock() def _get_user_by_token(token=None): return hs.parse_userid(self.auth_user_id) @@ -401,12 +401,12 @@ class RoomsMemberListTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), datastore=MemoryDataStore(), replication_layer=Mock(), state_handler=state_handler, persistence_service=persistence_service, ) + hs.get_handlers().federation_handler = Mock() self.auth_user_id = self.user_id @@ -479,12 +479,12 @@ class RoomsCreateTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), datastore=MemoryDataStore(), replication_layer=Mock(), state_handler=state_handler, persistence_service=persistence_service, ) + hs.get_handlers().federation_handler = Mock() def _get_user_by_token(token=None): return hs.parse_userid(self.auth_user_id) @@ -569,12 +569,12 @@ class RoomTopicTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), datastore=MemoryDataStore(), replication_layer=Mock(), state_handler=state_handler, persistence_service=persistence_service, ) + hs.get_handlers().federation_handler = Mock() def _get_user_by_token(token=None): return hs.parse_userid(self.auth_user_id) @@ -672,12 +672,12 @@ class RoomMemberStateTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), datastore=MemoryDataStore(), replication_layer=Mock(), state_handler=state_handler, persistence_service=persistence_service, ) + hs.get_handlers().federation_handler = Mock() def _get_user_by_token(token=None): return hs.parse_userid(self.auth_user_id) @@ -797,12 +797,12 @@ class RoomMessagesTestCase(RestTestCase): "test", db_pool=None, http_client=None, - federation=Mock(), datastore=MemoryDataStore(), replication_layer=Mock(), state_handler=state_handler, persistence_service=persistence_service, ) + hs.get_handlers().federation_handler = Mock() def _get_user_by_token(token=None): return hs.parse_userid(self.auth_user_id) diff --git a/tests/test_state.py b/tests/test_state.py index e64d15a3a2..58fd0bf3be 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -243,21 +243,24 @@ class StateTestCase(unittest.TestCase): state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) - tup = ("pdu_id", "origin.com", 5) - pdus = [tup] + snapshot = Mock() + snapshot.prev_state_pdu = state_pdu + event_id = "pdu_id@origin.com" - self.persistence.get_latest_pdus_in_context.return_value = pdus - self.persistence.get_current_state_pdu.return_value = state_pdu + def fill_out_prev_events(event): + event.prev_events = [event_id] + event.depth = 6 + snapshot.fill_out_prev_events = fill_out_prev_events - yield self.state.handle_new_event(event) + yield self.state.handle_new_event(event, snapshot) - self.assertLess(tup[2], event.depth) + self.assertLess(5, event.depth) self.assertEquals(1, len(event.prev_events)) prev_id = event.prev_events[0] - self.assertEqual(encode_event_id(tup[0], tup[1]), prev_id) + self.assertEqual(event_id, prev_id) self.assertEqual( encode_event_id(state_pdu.pdu_id, state_pdu.origin), diff --git a/tests/utils.py b/tests/utils.py index f40cbce51d..6666b06931 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -127,6 +127,15 @@ class MemoryDataStore(object): self.current_state = {} self.events = [] + class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")): + def fill_out_prev_events(self, event): + pass + + def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): + return self.Snapshot( + room_id, user_id, self.get_room_member(user_id, room_id) + ) + def register(self, user_id, token, password_hash): if user_id in self.tokens_to_users.values(): raise StoreError(400, "User in use.") |