diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 886e132e10..2473a2b2bb 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,7 +33,7 @@ class Auth(object):
self.store = hs.get_datastore()
@defer.inlineCallbacks
- def check(self, event, raises=False):
+ def check(self, event, snapshot, raises=False):
""" Checks if this event is correctly authed.
Returns:
@@ -48,7 +48,11 @@ class Auth(object):
allowed = yield self.is_membership_change_allowed(event)
defer.returnValue(allowed)
else:
- yield self.check_joined_room(event.room_id, event.user_id)
+ self._check_joined_room(
+ member=snapshot.membership_state,
+ user_id=snapshot.user_id,
+ room_id=snapshot.room_id,
+ )
defer.returnValue(True)
else:
raise AuthError(500, "Unknown event: %s" % event)
@@ -66,14 +70,18 @@ class Auth(object):
room_id=room_id,
user_id=user_id
)
- if not member or member.membership != Membership.JOIN:
- raise AuthError(403, "User %s not in room %s" %
- (user_id, room_id))
+ self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
except AttributeError:
pass
defer.returnValue(None)
+ def _check_joined_room(self, member, user_id, room_id):
+ if not member or member.membership != Membership.JOIN:
+ raise AuthError(403, "User %s not in room %s (%s)" % (
+ user_id, room_id, repr(member)
+ ))
+
@defer.inlineCallbacks
def is_membership_change_allowed(self, event):
target_user_id = event.state_key
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index a89770ed7b..6d292ccf9a 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -296,10 +296,6 @@ def setup():
db_name=db_name,
)
- # This object doesn't need to be saved because it's set as the handler for
- # the replication layer
- hs.get_federation()
-
hs.register_servlets()
hs.create_resource_tree(
diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py
deleted file mode 100644
index 984c1558e9..0000000000
--- a/synapse/federation/handler.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 matrix.org
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from .pdu_codec import PduCodec
-
-from synapse.api.errors import AuthError
-from synapse.util.logutils import log_function
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class FederationEventHandler(object):
- """ Responsible for:
- a) handling received Pdus before handing them on as Events to the rest
- of the home server (including auth and state conflict resoultion)
- b) converting events that were produced by local clients that may need
- to be sent to remote home servers.
- """
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.replication_layer = hs.get_replication_layer()
- self.state_handler = hs.get_state_handler()
- # self.auth_handler = gs.get_auth_handler()
- self.event_handler = hs.get_handlers().federation_handler
- self.server_name = hs.hostname
-
- self.lock_manager = hs.get_room_lock_manager()
-
- self.replication_layer.set_handler(self)
-
- self.pdu_codec = PduCodec(hs)
-
- @log_function
- @defer.inlineCallbacks
- def handle_new_event(self, event):
- """ Takes in an event from the client to server side, that has already
- been authed and handled by the state module, and sends it to any
- remote home servers that may be interested.
-
- Args:
- event
-
- Returns:
- Deferred: Resolved when it has successfully been queued for
- processing.
- """
- yield self.fill_out_prev_events(event)
-
- pdu = self.pdu_codec.pdu_from_event(event)
-
- if not hasattr(pdu, "destinations") or not pdu.destinations:
- pdu.destinations = []
-
- yield self.replication_layer.send_pdu(pdu)
-
- @log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit):
- pdus = yield self.replication_layer.backfill(dest, room_id, limit)
-
- if not pdus:
- defer.returnValue([])
-
- events = [
- self.pdu_codec.event_from_pdu(pdu)
- for pdu in pdus
- ]
-
- defer.returnValue(events)
-
- @log_function
- def get_state_for_room(self, destination, room_id):
- return self.replication_layer.get_state_for_context(
- destination, room_id
- )
-
- @log_function
- @defer.inlineCallbacks
- def on_receive_pdu(self, pdu, backfilled):
- """ Called by the ReplicationLayer when we have a new pdu. We need to
- do auth checks and put it throught the StateHandler.
- """
- event = self.pdu_codec.event_from_pdu(pdu)
-
- try:
- with (yield self.lock_manager.lock(pdu.context)):
- if event.is_state and not backfilled:
- is_new_state = yield self.state_handler.handle_new_state(
- pdu
- )
- if not is_new_state:
- return
- else:
- is_new_state = False
-
- yield self.event_handler.on_receive(event, is_new_state, backfilled)
-
- except AuthError:
- # TODO: Implement something in federation that allows us to
- # respond to PDU.
- raise
-
- return
-
- @defer.inlineCallbacks
- def _on_new_state(self, pdu, new_state_event):
- # TODO: Do any store stuff here. Notifiy C2S about this new
- # state.
-
- yield self.store.update_current_state(
- pdu_id=pdu.pdu_id,
- origin=pdu.origin,
- context=pdu.context,
- pdu_type=pdu.pdu_type,
- state_key=pdu.state_key
- )
-
- yield self.event_handler.on_receive(new_state_event)
-
- @defer.inlineCallbacks
- def fill_out_prev_events(self, event):
- if hasattr(event, "prev_events"):
- return
-
- results = yield self.store.get_latest_pdus_in_context(
- event.room_id
- )
-
- es = [
- "%s@%s" % (p_id, origin) for p_id, origin, _ in results
- ]
-
- event.prev_events = [e for e in es if e != event.event_id]
-
- if results:
- event.depth = max([int(v) for _, _, v in results]) + 1
- else:
- event.depth = 0
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index e0e4de4e8c..4cf72b2e42 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -25,7 +25,6 @@ from .units import Pdu
from synapse.util.logutils import log_function
-import copy
import json
import logging
@@ -41,28 +40,6 @@ class PduActions(object):
self.store = datastore
@log_function
- def persist_received(self, pdu):
- """ Persists the given `Pdu` that was received from a remote home
- server.
-
- Returns:
- Deferred
- """
- return self._persist(pdu)
-
- @defer.inlineCallbacks
- @log_function
- def persist_outgoing(self, pdu):
- """ Persists the given `Pdu` that this home server created.
-
- Returns:
- Deferred
- """
- ret = yield self._persist(pdu)
-
- defer.returnValue(ret)
-
- @log_function
def mark_as_processed(self, pdu):
""" Persist the fact that we have fully processed the given `Pdu`
@@ -73,25 +50,6 @@ class PduActions(object):
@defer.inlineCallbacks
@log_function
- def populate_previous_pdus(self, pdu):
- """ Given an outgoing `Pdu` fill out its `prev_ids` key with the `Pdu`s
- that we have received.
-
- Returns:
- Deferred
- """
- results = yield self.store.get_latest_pdus_in_context(pdu.context)
-
- pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in results]
-
- vs = [int(v) for _, _, v in results]
- if vs:
- pdu.depth = max(vs) + 1
- else:
- pdu.depth = 0
-
- @defer.inlineCallbacks
- @log_function
def after_transaction(self, transaction_id, destination, origin):
""" Returns all `Pdu`s that we sent to the given remote home server
after a given transaction id.
@@ -143,28 +101,6 @@ class PduActions(object):
depth=pdu.depth
)
- @defer.inlineCallbacks
- @log_function
- def _persist(self, pdu):
- kwargs = copy.copy(pdu.__dict__)
- unrec_keys = copy.copy(pdu.unrecognized_keys)
- del kwargs["content"]
- kwargs["content_json"] = json.dumps(pdu.content)
- kwargs["unrecognized_keys"] = json.dumps(unrec_keys)
-
- logger.debug("Persisting: %s", repr(kwargs))
-
- if pdu.is_state:
- ret = yield self.store.persist_state(**kwargs)
- else:
- ret = yield self.store.persist_pdu(**kwargs)
-
- yield self.store.update_min_depth_for_context(
- pdu.context, pdu.depth
- )
-
- defer.returnValue(ret)
-
class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions.
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index cf634a64b2..38ae360bcd 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -134,10 +134,8 @@ class ReplicationLayer(object):
logger.debug("[%s] Persisting PDU", pdu.pdu_id)
- #yield self.pdu_actions.populate_previous_pdus(pdu)
-
# Save *before* trying to send
- yield self.pdu_actions.persist_outgoing(pdu)
+ yield self.store.persist_event(pdu=pdu)
logger.debug("[%s] Persisted PDU", pdu.pdu_id)
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
@@ -450,7 +448,7 @@ class ReplicationLayer(object):
logger.exception("Failed to get PDU")
# Persist the Pdu, but don't mark it as processed yet.
- yield self.pdu_actions.persist_received(pdu)
+ yield self.store.persist_event(pdu=pdu)
if not backfilled:
ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 3f07b5aa4a..78df9ac53e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from twisted.internet import defer
class BaseHandler(object):
@@ -26,3 +26,24 @@ class BaseHandler(object):
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.hs = hs
+
+
+class BaseRoomHandler(BaseHandler):
+
+ @defer.inlineCallbacks
+ def _on_new_room_event(self, event, snapshot, extra_destinations=[]):
+ snapshot.fill_out_prev_events(event)
+
+ store_id = yield self.store.persist_event(event)
+
+ destinations = set(extra_destinations)
+ # Send a PDU to all hosts who have joined the room.
+ destinations.update((yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )))
+ event.destinations = list(destinations)
+
+ self.notifier.on_new_room_event(event, store_id)
+
+ federation_handler = self.hs.get_handlers().federation_handler
+ yield federation_handler.handle_new_event(event, snapshot)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bfc1ab86f2..7253f56322 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,6 +20,9 @@ from ._base import BaseHandler
from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
from synapse.api.constants import Membership
from synapse.util.logutils import log_function
+from synapse.federation.pdu_codec import PduCodec
+
+from synapse.api.errors import AuthError
from twisted.internet import defer
@@ -30,8 +33,14 @@ logger = logging.getLogger(__name__)
class FederationHandler(BaseHandler):
+ """Handles events that originated from federation.
+ Responsible for:
+ a) handling received Pdus before handing them on as Events to the rest
+ of the home server (including auth and state conflict resoultion)
+ b) converting events that were produced by local clients that may need
+ to be sent to remote home servers.
+ """
- """Handles events that originated from federation."""
def __init__(self, hs):
super(FederationHandler, self).__init__(hs)
@@ -42,9 +51,67 @@ class FederationHandler(BaseHandler):
self.waiting_for_join_list = {}
+ self.store = hs.get_datastore()
+ self.replication_layer = hs.get_replication_layer()
+ self.state_handler = hs.get_state_handler()
+ # self.auth_handler = gs.get_auth_handler()
+ self.server_name = hs.hostname
+
+ self.lock_manager = hs.get_room_lock_manager()
+
+ self.replication_layer.set_handler(self)
+
+ self.pdu_codec = PduCodec(hs)
+
@log_function
@defer.inlineCallbacks
- def on_receive(self, event, is_new_state, backfilled):
+ def handle_new_event(self, event, snapshot):
+ """ Takes in an event from the client to server side, that has already
+ been authed and handled by the state module, and sends it to any
+ remote home servers that may be interested.
+
+ Args:
+ event
+ snapshot (.storage.Snapshot): THe snapshot the event happened after
+
+ Returns:
+ Deferred: Resolved when it has successfully been queued for
+ processing.
+ """
+
+ pdu = self.pdu_codec.pdu_from_event(event)
+
+ if not hasattr(pdu, "destinations") or not pdu.destinations:
+ pdu.destinations = []
+
+ yield self.replication_layer.send_pdu(pdu)
+
+ @log_function
+ def get_state_for_room(self, destination, room_id):
+ return self.replication_layer.get_state_for_context(
+ destination, room_id
+ )
+
+ @log_function
+ @defer.inlineCallbacks
+ def on_receive_pdu(self, pdu, backfilled):
+ """ Called by the ReplicationLayer when we have a new pdu. We need to
+ do auth checks and put it throught the StateHandler.
+ """
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ with (yield self.lock_manager.lock(pdu.context)):
+ if event.is_state and not backfilled:
+ is_new_state = yield self.state_handler.handle_new_state(
+ pdu
+ )
+ if not is_new_state:
+ return
+ else:
+ is_new_state = False
+ # TODO: Implement something in federation that allows us to
+ # respond to PDU.
+
if hasattr(event, "state_key") and not is_new_state:
logger.debug("Ignoring old state.")
return
@@ -86,8 +153,7 @@ class FederationHandler(BaseHandler):
if not room:
# Huh, let's try and get the current state
try:
- federation = self.hs.get_federation()
- yield federation.get_state_for_room(
+ yield self.get_state_for_room(
event.origin, event.room_id
)
@@ -119,11 +185,10 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id
)
-
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit):
- events = yield self.hs.get_federation().backfill(dest, room_id, limit)
+ events = yield self._backfill(dest, room_id, limit)
for event in events:
try:
@@ -133,10 +198,23 @@ class FederationHandler(BaseHandler):
defer.returnValue(events)
+ @defer.inlineCallbacks
+ def _backfill(self, dest, room_id, limit):
+ pdus = yield self.replication_layer.backfill(dest, room_id, limit)
+
+ if not pdus:
+ defer.returnValue([])
+
+ events = [
+ self.pdu_codec.event_from_pdu(pdu)
+ for pdu in pdus
+ ]
+
+ defer.returnValue(events)
+
@log_function
@defer.inlineCallbacks
- def do_invite_join(self, target_host, room_id, joinee, content):
- federation = self.hs.get_federation()
+ def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
@@ -146,7 +224,7 @@ class FederationHandler(BaseHandler):
# First get current state to see if we are already joined.
try:
- yield federation.get_state_for_room(target_host, room_id)
+ yield self.get_state_for_room(target_host, room_id)
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
@@ -166,7 +244,8 @@ class FederationHandler(BaseHandler):
new_event.destinations = [target_host]
- yield federation.handle_new_event(new_event)
+ snapshot.fill_out_prev_events(new_event)
+ yield self.handle_new_event(new_event, snapshot)
# TODO (erikj): Time out here.
d = defer.Deferred()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5a4569ac95..7b4b051888 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,14 +25,14 @@ from synapse.api.events.room import (
from synapse.api.streams.event import EventStream, EventsStreamData
from synapse.handlers.presence import PresenceStreamData
from synapse.util import stringutils
-from ._base import BaseHandler
+from ._base import BaseRoomHandler
import logging
logger = logging.getLogger(__name__)
-class MessageHandler(BaseHandler):
+class MessageHandler(BaseRoomHandler):
def __init__(self, hs):
super(MessageHandler, self).__init__(hs)
@@ -84,20 +84,12 @@ class MessageHandler(BaseHandler):
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
- with (yield self.room_lock.lock(event.room_id)):
- if not suppress_auth:
- yield self.auth.check(event, raises=True)
+ snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- # store message in db
- store_id = yield self.store.persist_event(event)
+ if not suppress_auth:
+ yield self.auth.check(event, snapshot, raises=True)
- event.destinations = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
-
- self.notifier.on_new_room_event(event, store_id)
-
- yield self.hs.get_federation().handle_new_event(event)
+ yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@@ -134,23 +126,16 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong.
"""
- with (yield self.room_lock.lock(event.room_id)):
- yield self.auth.check(event, raises=True)
+ snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- if stamp_event:
- event.content["hsob_ts"] = int(self.clock.time_msec())
+ yield self.auth.check(event, snapshot, raises=True)
- yield self.state_handler.handle_new_event(event)
-
- # store in db
- store_id = yield self.store.persist_event(event)
+ if stamp_event:
+ event.content["hsob_ts"] = int(self.clock.time_msec())
- event.destinations = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
- self.notifier.on_new_room_event(event, store_id)
+ yield self.state_handler.handle_new_event(event, snapshot)
- yield self.hs.get_federation().handle_new_event(event)
+ yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
@@ -219,18 +204,12 @@ class MessageHandler(BaseHandler):
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
- with (yield self.room_lock.lock(event.room_id)):
- yield self.auth.check(event, raises=True)
-
- # store message in db
- store_id = yield self.store.persist_event(event)
+ snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- event.destinations = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
- yield self.hs.get_federation().handle_new_event(event)
+ yield self.auth.check(event, snapshot, raises=True)
- self.notifier.on_new_room_event(event, store_id)
+ # store message in db
+ yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
@@ -312,7 +291,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret)
-class RoomCreationHandler(BaseHandler):
+class RoomCreationHandler(BaseRoomHandler):
@defer.inlineCallbacks
def create_room(self, user_id, room_id, config):
@@ -383,6 +362,13 @@ class RoomCreationHandler(BaseHandler):
content=config,
)
+ snapshot = yield self.store.snapshot_room(
+ room_id=room_id,
+ user_id=user_id,
+ state_type=RoomConfigEvent.TYPE,
+ state_key="",
+ )
+
if room_alias:
yield self.store.create_room_alias_association(
room_id=room_id,
@@ -390,10 +376,11 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname],
)
- yield self.state_handler.handle_new_event(config_event)
+ yield self.state_handler.handle_new_event(config_event, snapshot)
# store_id = persist...
- yield self.hs.get_federation().handle_new_event(config_event)
+ federation_handler = self.hs.get_handlers().federation_handler
+ yield federation_handler.handle_new_event(config_event, snapshot)
# self.notifier.on_new_room_event(event, store_id)
content = {"membership": Membership.JOIN}
@@ -418,7 +405,7 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result)
-class RoomMemberHandler(BaseHandler):
+class RoomMemberHandler(BaseRoomHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
@@ -529,6 +516,11 @@ class RoomMemberHandler(BaseHandler):
"""
target_user_id = event.state_key
+ snapshot = yield self.store.snapshot_room(
+ event.room_id, event.user_id,
+ RoomMemberEvent.TYPE, target_user_id
+ )
+ ## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member(
target_user_id, event.room_id
)
@@ -549,24 +541,22 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the
# invite/join dance.
if event.membership == Membership.JOIN:
- yield self._do_join(event, do_auth=do_auth)
+ yield self._do_join(event, snapshot, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
if do_auth:
- yield self.auth.check(event, raises=True)
+ yield self.auth.check(event, snapshot, raises=True)
- prev_state = yield self.store.get_room_member(
- target_user_id, event.room_id
- )
if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP.
defer.returnValue({})
return
- yield self.state_handler.handle_new_event(event)
+ yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
+ snapshot=snapshot,
)
defer.returnValue({"room_id": room_id})
@@ -596,12 +586,16 @@ class RoomMemberHandler(BaseHandler):
content=content,
)
- yield self._do_join(new_event, room_host=host, do_auth=True)
+ snapshot = yield self.store.snapshot_room(
+ room_id, joinee, RoomMemberEvent.TYPE, joinee
+ )
+
+ yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
- def _do_join(self, event, room_host=None, do_auth=True):
+ def _do_join(self, event, snapshot, room_host=None, do_auth=True):
joinee = self.hs.parse_userid(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id
@@ -623,6 +617,7 @@ class RoomMemberHandler(BaseHandler):
elif room_host:
should_do_dance = True
else:
+ # TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
joinee.to_string(), room_id
)
@@ -642,7 +637,7 @@ class RoomMemberHandler(BaseHandler):
if should_do_dance:
handler = self.hs.get_handlers().federation_handler
have_joined = yield handler.do_invite_join(
- room_host, room_id, event.user_id, event.content
+ room_host, room_id, event.user_id, event.content, snapshot
)
# We want to do the _do_update inside the room lock.
@@ -650,12 +645,13 @@ class RoomMemberHandler(BaseHandler):
logger.debug("Doing normal join")
if do_auth:
- yield self.auth.check(event, raises=True)
+ yield self.auth.check(event, snapshot, raises=True)
- yield self.state_handler.handle_new_event(event)
+ yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
+ snapshot=snapshot,
)
user = self.hs.parse_userid(event.user_id)
@@ -699,39 +695,26 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms])
- @defer.inlineCallbacks
- def _do_local_membership_update(self, event, membership):
- # store membership
- store_id = yield self.store.persist_event(event)
-
- # Send a PDU to all hosts who have joined the room.
- destinations = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
+ def _do_local_membership_update(self, event, membership, snapshot):
+ destinations = []
# If we're inviting someone, then we should also send it to that
# HS.
target_user_id = event.state_key
if membership == Membership.INVITE:
- host = UserID.from_string(
- target_user_id, self.hs
- ).domain
+ host = UserID.from_string(target_user_id, self.hs).domain
destinations.append(host)
# If we are joining a remote HS, include that.
if membership == Membership.JOIN:
- host = UserID.from_string(
- target_user_id, self.hs
- ).domain
+ host = UserID.from_string(target_user_id, self.hs).domain
destinations.append(host)
- event.destinations = list(set(destinations))
-
- yield self.hs.get_federation().handle_new_event(event)
- self.notifier.on_new_room_event(event, store_id)
-
+ return self._on_new_room_event(
+ event, snapshot, extra_destinations=destinations
+ )
-class RoomListHandler(BaseHandler):
+class RoomListHandler(BaseRoomHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
diff --git a/synapse/server.py b/synapse/server.py
index 24f3a88103..94facf9d99 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -20,7 +20,6 @@
# Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication
-from synapse.federation.handler import FederationEventHandler
from synapse.api.events.factory import EventFactory
from synapse.api.notifier import Notifier
from synapse.api.auth import Auth
@@ -58,7 +57,6 @@ class BaseHomeServer(object):
'http_client',
'db_pool',
'persistence_service',
- 'federation',
'replication_layer',
'datastore',
'event_factory',
@@ -165,9 +163,6 @@ class HomeServer(BaseHomeServer):
def build_replication_layer(self):
return initialize_http_replication(self)
- def build_federation(self):
- return FederationEventHandler(self)
-
def build_datastore(self):
return DataStore(self)
diff --git a/synapse/state.py b/synapse/state.py
index ca8e1ca630..e1a1a159bb 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -45,7 +45,7 @@ class StateHandler(object):
@defer.inlineCallbacks
@log_function
- def handle_new_event(self, event):
+ def handle_new_event(self, event, snapshot):
""" Given an event this works out if a) we have sufficient power level
to update the state and b) works out what the prev_state should be.
@@ -70,25 +70,13 @@ class StateHandler(object):
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
- results = yield self.store.get_latest_pdus_in_context(
- event.room_id
- )
+ snapshot.fill_out_prev_events(event)
event.prev_events = [
- encode_event_id(p_id, origin) for p_id, origin, _ in results
- ]
- event.prev_events = [
e for e in event.prev_events if e != event.event_id
]
- if results:
- event.depth = max([int(v) for _, _, v in results]) + 1
- else:
- event.depth = 0
-
- current_state = yield self.store.get_current_state_pdu(
- key.context, key.type, key.state_key
- )
+ current_state = snapshot.prev_state_pdu
if current_state:
event.prev_state = encode_event_id(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 38ab03c45c..514d7eeb69 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -57,20 +57,22 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
- def persist_event(self, event, backfilled=False):
- if event.type == RoomMemberEvent.TYPE:
- yield self._store_room_member(event)
- elif event.type == FeedbackEvent.TYPE:
- yield self._store_feedback(event)
-# elif event.type == RoomConfigEvent.TYPE:
-# yield self._store_room_config(event)
- elif event.type == RoomNameEvent.TYPE:
- yield self._store_room_name(event)
- elif event.type == RoomTopicEvent.TYPE:
- yield self._store_room_topic(event)
-
- ret = yield self._store_event(event, backfilled)
- defer.returnValue(ret)
+ def persist_event(self, event=None, backfilled=False, pdu=None):
+ stream_ordering = None
+ if backfilled:
+ if not self.min_token_deferred.called:
+ yield self.min_token_deferred
+ self.min_token -= 1
+ stream_ordering = self.min_token
+
+ latest = yield self._db_pool.runInteraction(
+ self._persist_pdu_event_txn,
+ pdu=pdu,
+ event=event,
+ backfilled=backfilled,
+ stream_ordering=stream_ordering,
+ )
+ defer.returnValue(latest)
@defer.inlineCallbacks
def get_event(self, event_id):
@@ -89,12 +91,42 @@ class DataStore(RoomMemberStore, RoomStore,
event = self._parse_event_from_row(events_dict)
defer.returnValue(event)
- @defer.inlineCallbacks
+ def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
+ backfilled=False, stream_ordering=None):
+ if pdu is not None:
+ self._persist_event_pdu_txn(txn, pdu)
+ if event is not None:
+ self._persist_event_txn(txn, event, backfilled, stream_ordering)
+
+ def _persist_event_pdu_txn(self, txn, pdu):
+ cols = dict(pdu.__dict__)
+ unrec_keys = dict(pdu.unrecognized_keys)
+ del cols["content"]
+ del cols["prev_pdus"]
+ cols["content_json"] = json.dumps(pdu.content)
+ cols["unrecognized_keys"] = json.dumps(unrec_keys)
+
+ logger.debug("Persisting: %s", repr(cols))
+
+ if pdu.is_state:
+ self._persist_state_txn(txn, pdu.prev_pdus, cols)
+ else:
+ self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
+
+ self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
+
@log_function
- def _store_event(self, event, backfilled):
- # FIXME (erikj): This should be removed when we start amalgamating
- # event and pdu storage
- yield self.hs.get_federation().fill_out_prev_events(event)
+ def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None):
+ if event.type == RoomMemberEvent.TYPE:
+ self._store_room_member_txn(txn, event)
+ elif event.type == FeedbackEvent.TYPE:
+ self._store_feedback_txn(txn,event)
+# elif event.type == RoomConfigEvent.TYPE:
+# self._store_room_config_txn(txn, event)
+ elif event.type == RoomNameEvent.TYPE:
+ self._store_room_name_txn(txn, event)
+ elif event.type == RoomTopicEvent.TYPE:
+ self._store_room_topic_txn(txn, event)
vals = {
"topological_ordering": event.depth,
@@ -105,17 +137,14 @@ class DataStore(RoomMemberStore, RoomStore,
"processed": True,
}
+ if stream_ordering is not None:
+ vals["stream_ordering"] = stream_ordering
+
if hasattr(event, "outlier"):
vals["outlier"] = event.outlier
else:
vals["outlier"] = False
- if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- vals["stream_ordering"] = self.min_token
-
unrec = {
k: v
for k, v in event.get_full_dict().items()
@@ -124,7 +153,7 @@ class DataStore(RoomMemberStore, RoomStore,
vals["unrecognized_keys"] = json.dumps(unrec)
try:
- yield self._simple_insert("events", vals)
+ self._simple_insert_txn(txn, "events", vals)
except:
logger.exception(
"Failed to persist, probably duplicate: %s",
@@ -143,9 +172,10 @@ class DataStore(RoomMemberStore, RoomStore,
if hasattr(event, "prev_state"):
vals["prev_state"] = event.prev_state
- yield self._simple_insert("state_events", vals)
+ self._simple_insert_txn(txn, "state_events", vals)
- yield self._simple_insert(
+ self._simple_insert_txn(
+ txn,
"current_state_events",
{
"event_id": event.event_id,
@@ -155,8 +185,7 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
- latest = yield self.get_room_events_max_id()
- defer.returnValue(latest)
+ return self._get_room_events_max_id_txn(txn)
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -192,6 +221,85 @@ class DataStore(RoomMemberStore, RoomStore,
defer.returnValue(self.min_token)
+ def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+ """Snapshot the room for an update by a user
+ Args:
+ room_id (synapse.types.RoomId): The room to snapshot.
+ user_id (synapse.types.UserId): The user to snapshot the room for.
+ state_type (str): Optional state type to snapshot.
+ state_key (str): Optional state key to snapshot.
+ Returns:
+ synapse.storage.Snapshot: A snapshot of the state of the room.
+ """
+ def _snapshot(txn):
+ membership_state = self._get_room_member(txn, user_id, room_id)
+ prev_pdus = self._get_latest_pdus_in_context(
+ txn, room_id
+ )
+ if state_type is not None and state_key is not None:
+ prev_state_pdu = self._get_current_state_pdu(
+ txn, room_id, state_type, state_key
+ )
+ else:
+ prev_state_pdu = None
+
+ return Snapshot(
+ store=self,
+ room_id=room_id,
+ user_id=user_id,
+ prev_pdus=prev_pdus,
+ membership_state=membership_state,
+ state_type=state_type,
+ state_key=state_key,
+ prev_state_pdu=prev_state_pdu,
+ )
+
+ return self._db_pool.runInteraction(_snapshot)
+
+
+class Snapshot(object):
+ """Snapshot of the state of a room
+ Args:
+ store (DataStore): The datastore.
+ room_id (RoomId): The room of the snapshot.
+ user_id (UserId): The user this snapshot is for.
+ prev_pdus (list): The list of PDU ids this snapshot is after.
+ membership_state (RoomMemberEvent): The current state of the user in
+ the room.
+ state_type (str, optional): State type captured by the snapshot
+ state_key (str, optional): State key captured by the snapshot
+ prev_state_pdu (PduEntry, optional): pdu id of
+ the previous value of the state type and key in the room.
+ """
+
+ def __init__(self, store, room_id, user_id, prev_pdus,
+ membership_state, state_type=None, state_key=None,
+ prev_state_pdu=None):
+ self.store = store
+ self.room_id = room_id
+ self.user_id = user_id
+ self.prev_pdus = prev_pdus
+ self.membership_state = membership_state
+ self.state_type = state_type
+ self.state_key = state_key
+ self.prev_state_pdu = prev_state_pdu
+
+ def fill_out_prev_events(self, event):
+ if hasattr(event, "prev_events"):
+ return
+
+ es = [
+ "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
+ ]
+
+ event.prev_events = [e for e in es if e != event.event_id]
+
+ if self.prev_pdus:
+ event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
+ else:
+ event.depth = 0
+
+
def schema_path(schema):
""" Get a filesystem path for the named database schema
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 75aab2d3b9..33d56f47ce 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -86,16 +86,18 @@ class SQLBaseStore(object):
table : string giving the table name
values : dict of new column names and values for them
"""
+ return self._db_pool.runInteraction(
+ self._simple_insert_txn, table, values,
+ )
+
+ def _simple_insert_txn(self, txn, table, values):
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in values),
", ".join("?" for k in values)
)
-
- def func(txn):
- txn.execute(sql, values.values())
- return txn.lastrowid
- return self._db_pool.runInteraction(func)
+ txn.execute(sql, values.values())
+ return txn.lastrowid
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False):
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
index 513b72d279..bac3dea955 100644
--- a/synapse/storage/feedback.py
+++ b/synapse/storage/feedback.py
@@ -20,8 +20,8 @@ from ._base import SQLBaseStore
class FeedbackStore(SQLBaseStore):
- def _store_feedback(self, event):
- return self._simple_insert("feedback", {
+ def _store_feedback_txn(self, txn, event):
+ self._simple_insert_txn(txn, "feedback", {
"event_id": event.event_id,
"feedback_type": event.content["type"],
"room_id": event.room_id,
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 7655f43ede..9fd44f2454 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -114,7 +114,7 @@ class PduStore(SQLBaseStore):
return self._get_pdu_tuples(txn, res)
- def persist_pdu(self, prev_pdus, **cols):
+ def _persist_pdu_txn(self, txn, prev_pdus, cols):
"""Inserts a (non-state) PDU into the database.
Args:
@@ -122,11 +122,6 @@ class PduStore(SQLBaseStore):
prev_pdus (list)
**cols: The columns to insert into the PdusTable.
"""
- return self._db_pool.runInteraction(
- self._persist_pdu, prev_pdus, cols
- )
-
- def _persist_pdu(self, txn, prev_pdus, cols):
entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
@@ -262,7 +257,7 @@ class PduStore(SQLBaseStore):
return row[0] if row else None
- def update_min_depth_for_context(self, context, depth):
+ def _update_min_depth_for_context_txn(self, txn, context, depth):
"""Update the minimum `depth` of the given context, which is the line
on which we stop backfilling backwards.
@@ -270,11 +265,6 @@ class PduStore(SQLBaseStore):
context (str)
depth (int)
"""
- return self._db_pool.runInteraction(
- self._update_min_depth_for_context, context, depth
- )
-
- def _update_min_depth_for_context(self, txn, context, depth):
min_depth = self._get_min_depth_interaction(txn, context)
do_insert = depth < min_depth if min_depth else True
@@ -286,7 +276,7 @@ class PduStore(SQLBaseStore):
(context, depth)
)
- def get_latest_pdus_in_context(self, context):
+ def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
key
@@ -295,11 +285,6 @@ class PduStore(SQLBaseStore):
txn
context
"""
- return self._db_pool.runInteraction(
- self._get_latest_pdus_in_context, context
- )
-
- def _get_latest_pdus_in_context(self, txn, context):
query = (
"SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
"INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
@@ -485,7 +470,7 @@ class StatePduStore(SQLBaseStore):
"""A collection of queries for handling state PDUs.
"""
- def persist_state(self, prev_pdus, **cols):
+ def _persist_state_txn(self, txn, prev_pdus, cols):
"""Inserts a state PDU into the database
Args:
@@ -493,12 +478,6 @@ class StatePduStore(SQLBaseStore):
prev_pdus (list)
**cols: The columns to insert into the PdusTable and StatePdusTable
"""
-
- return self._db_pool.runInteraction(
- self._persist_state, prev_pdus, cols
- )
-
- def _persist_state(self, txn, prev_pdus, cols):
pdu_entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index a5751005ef..d1f1a232f8 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -129,8 +129,9 @@ class RoomStore(SQLBaseStore):
defer.returnValue(ret)
- def _store_room_topic(self, event):
- return self._simple_insert(
+ def _store_room_topic_txn(self, txn, event):
+ self._simple_insert_txn(
+ txn,
"topics",
{
"event_id": event.event_id,
@@ -139,8 +140,9 @@ class RoomStore(SQLBaseStore):
}
)
- def _store_room_name(self, event):
- return self._simple_insert(
+ def _store_room_name_txn(self, txn, event):
+ self._simple_insert_txn(
+ txn,
"room_names",
{
"event_id": event.event_id,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4ad37af0f3..2746126e85 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -26,14 +26,14 @@ logger = logging.getLogger(__name__)
class RoomMemberStore(SQLBaseStore):
- @defer.inlineCallbacks
- def _store_room_member(self, event):
+ def _store_room_member_txn(self, txn, event):
"""Store a room member in the database.
"""
target_user_id = event.state_key
domain = self.hs.parse_userid(target_user_id).domain
- yield self._simple_insert(
+ self._simple_insert_txn(
+ txn,
"room_memberships",
{
"event_id": event.event_id,
@@ -50,13 +50,13 @@ class RoomMemberStore(SQLBaseStore):
"INSERT OR IGNORE INTO room_hosts (room_id, host) "
"VALUES (?, ?)"
)
- yield self._execute(None, sql, event.room_id, domain)
+ txn.execute(sql, (event.room_id, domain))
else:
sql = (
"DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
)
- yield self._execute(None, sql, event.room_id, domain)
+ txn.execute(sql, (event.room_id, domain))
@defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
@@ -75,6 +75,24 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(rows[0] if rows else None)
+ def _get_room_member(self, txn, user_id, room_id):
+ sql = (
+ "SELECT e.* FROM events as e"
+ " INNER JOIN room_memberships as m"
+ " ON e.event_id = m.event_id"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id"
+ " WHERE m.user_id = ? and e.room_id = ?"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = self.cursor_to_dict(txn)
+ if rows:
+ return self._parse_event_from_row(rows[0])
+ else:
+ return None
+
+
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index cae80563b4..ac887e2957 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -281,17 +281,20 @@ class StreamStore(SQLBaseStore):
)
)
- @defer.inlineCallbacks
def get_room_events_max_id(self):
- res = yield self._execute_and_decode(
+ return self._db_pool.runInteraction(self._get_room_events_max_id_txn)
+
+ def _get_room_events_max_id_txn(self, txn):
+ txn.execute(
"SELECT MAX(stream_ordering) as m FROM events"
)
+ res = self.cursor_to_dict(txn)
+
logger.debug("get_room_events_max_id: %s", res)
if not res or not res[0] or not res[0]["m"]:
- defer.returnValue("s1")
- return
+ return "s1"
key = res[0]["m"] + 1
- defer.returnValue("s%d" % (key,))
+ return "s%d" % (key,)
diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py
index 58590e4fcd..938b57bec9 100644
--- a/tests/federation/test_federation.py
+++ b/tests/federation/test_federation.py
@@ -58,7 +58,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_persistence = Mock(spec=[
"get_current_state_for_context",
"get_pdu",
- "persist_pdu",
+ "persist_event",
"update_min_depth_for_context",
"prep_send_transaction",
"delivered_txn",
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index a92d825f49..5ad40e484c 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -22,8 +22,9 @@ from synapse.api.events.room import (
from synapse.api.constants import Membership
from synapse.handlers.federation import FederationHandler
from synapse.server import HomeServer
+from synapse.federation.units import Pdu
-from mock import NonCallableMock
+from mock import NonCallableMock, ANY
import logging
@@ -60,37 +61,42 @@ class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_msg(self):
- event = self.hs.get_event_factory().create_event(
- etype=MessageEvent.TYPE,
- msg_id="bob",
- room_id="foo",
+ pdu = Pdu(
+ pdu_type=MessageEvent.TYPE,
+ context="foo",
content={"msgtype": u"fooo"},
+ ts=0,
+ pdu_id="a",
+ origin="b",
)
store_id = "ASD"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(True)
- yield self.handlers.federation_handler.on_receive(event, False, False)
+ yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
- self.datastore.persist_event.assert_called_once_with(event, False)
+ self.datastore.persist_event.assert_called_once_with(ANY, False)
self.notifier.on_new_room_event.assert_called_once_with(
- event, store_id)
+ ANY, store_id)
@defer.inlineCallbacks
def test_invite_join_target_this(self):
room_id = "foo"
user_id = "@bob:red"
- event = self.hs.get_event_factory().create_event(
- etype=InviteJoinEvent.TYPE,
+ pdu = Pdu(
+ pdu_type=InviteJoinEvent.TYPE,
user_id=user_id,
target_host=self.hostname,
- room_id=room_id,
+ context=room_id,
content={},
+ ts=0,
+ pdu_id="a",
+ origin="b",
)
- yield self.handlers.federation_handler.on_receive(event, False, False)
+ yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
mem_handler = self.handlers.room_member_handler
self.assertEquals(1, mem_handler.change_membership.call_count)
@@ -107,15 +113,18 @@ class FederationTestCase(unittest.TestCase):
room_id = "foo"
user_id = "@bob:red"
- event = self.hs.get_event_factory().create_event(
- etype=InviteJoinEvent.TYPE,
+ pdu = Pdu(
+ pdu_type=InviteJoinEvent.TYPE,
user_id=user_id,
- target_user_id="@red:not%s" % self.hostname,
- room_id=room_id,
+ state_key="@red:not%s" % self.hostname,
+ context=room_id,
content={},
+ ts=0,
+ pdu_id="a",
+ origin="b",
)
- yield self.handlers.federation_handler.on_receive(event, False, False)
+ yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
mem_handler = self.handlers.room_member_handler
self.assertEquals(0, mem_handler.change_membership.call_count)
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index fddab8f74f..a84dbcc471 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room_member",
"get_room",
"store_room",
+ "snapshot_room",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@@ -52,29 +53,36 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"profile_handler",
+ "federation_handler",
]),
auth=NonCallableMock(spec_set=["check"]),
- federation=NonCallableMock(spec_set=[
- "handle_new_event",
- "get_state_for_room",
- ]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
)
+ self.federation = NonCallableMock(spec_set=[
+ "handle_new_event",
+ "get_state_for_room",
+ ])
+
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
- self.federation = hs.get_federation()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.hs = hs
+ self.handlers.federation_handler = self.federation
+
self.distributor.declare("collect_presencelike_data")
self.handlers.room_member_handler = RoomMemberHandler(self.hs)
self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler
+ self.snapshot = Mock()
+ self.datastore.snapshot_room.return_value = self.snapshot
+
+
@defer.inlineCallbacks
def test_invite(self):
room_id = "!foo:red"
@@ -104,8 +112,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
# Actual invocation
yield self.room_member_handler.change_membership(event)
- self.state_handler.handle_new_event.assert_called_once_with(event)
- self.federation.handle_new_event.assert_called_once_with(event)
+ self.state_handler.handle_new_event.assert_called_once_with(
+ event, self.snapshot,
+ )
+ self.federation.handle_new_event.assert_called_once_with(
+ event, self.snapshot,
+ )
self.assertEquals(
set(["blue", "red", "green"]),
@@ -116,7 +128,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event
)
self.notifier.on_new_room_event.assert_called_once_with(
- event, store_id)
+ event, store_id
+ )
self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called)
@@ -148,6 +161,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.datastore.get_joined_hosts_for_room.side_effect = get_joined
+
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
@@ -163,8 +177,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
# Actual invocation
yield self.room_member_handler.change_membership(event)
- self.state_handler.handle_new_event.assert_called_once_with(event)
- self.federation.handle_new_event.assert_called_once_with(event)
+ self.state_handler.handle_new_event.assert_called_once_with(
+ event, self.snapshot
+ )
+ self.federation.handle_new_event.assert_called_once_with(
+ event, self.snapshot
+ )
self.assertEquals(
set(["red", "green"]),
@@ -312,27 +330,31 @@ class RoomCreationTest(unittest.TestCase):
db_pool=None,
datastore=NonCallableMock(spec_set=[
"store_room",
+ "snapshot_room",
]),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_creation_handler",
"room_member_handler",
+ "federation_handler",
]),
auth=NonCallableMock(spec_set=["check"]),
- federation=NonCallableMock(spec_set=[
- "handle_new_event",
- ]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
)
+ self.federation = NonCallableMock(spec_set=[
+ "handle_new_event",
+ ])
+
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
- self.federation = hs.get_federation()
self.state_handler = hs.get_state_handler()
self.hs = hs
+ self.handlers.federation_handler = self.federation
+
self.handlers.room_creation_handler = RoomCreationHandler(self.hs)
self.room_creation_handler = self.handlers.room_creation_handler
diff --git a/tests/rest/test_events.py b/tests/rest/test_events.py
index 94ad8910e3..3099a24e8c 100644
--- a/tests/rest/test_events.py
+++ b/tests/rest/test_events.py
@@ -128,9 +128,9 @@ class EventStreamPermissionsTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
replication_layer=Mock(),
state_handler=state_handler,
+ datastore=MemoryDataStore(),
persistence_service=persistence_service,
clock=Mock(spec=[
"call_later",
@@ -139,9 +139,10 @@ class EventStreamPermissionsTestCase(RestTestCase):
]),
)
+ hs.get_handlers().federation_handler = Mock()
+
hs.get_clock().time_msec.return_value = 1000000
- hs.datastore = MemoryDataStore()
synapse.rest.register.register_servlets(hs, self.mock_resource)
synapse.rest.events.register_servlets(hs, self.mock_resource)
synapse.rest.room.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/test_rooms.py b/tests/rest/test_rooms.py
index 589b434446..914dc28f53 100644
--- a/tests/rest/test_rooms.py
+++ b/tests/rest/test_rooms.py
@@ -54,12 +54,12 @@ class RoomPermissionsTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
)
+ hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None):
return hs.parse_userid(self.auth_user_id)
@@ -401,12 +401,12 @@ class RoomsMemberListTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
)
+ hs.get_handlers().federation_handler = Mock()
self.auth_user_id = self.user_id
@@ -479,12 +479,12 @@ class RoomsCreateTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
)
+ hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None):
return hs.parse_userid(self.auth_user_id)
@@ -569,12 +569,12 @@ class RoomTopicTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
)
+ hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None):
return hs.parse_userid(self.auth_user_id)
@@ -672,12 +672,12 @@ class RoomMemberStateTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
)
+ hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None):
return hs.parse_userid(self.auth_user_id)
@@ -797,12 +797,12 @@ class RoomMessagesTestCase(RestTestCase):
"test",
db_pool=None,
http_client=None,
- federation=Mock(),
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
)
+ hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None):
return hs.parse_userid(self.auth_user_id)
diff --git a/tests/test_state.py b/tests/test_state.py
index e64d15a3a2..58fd0bf3be 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -243,21 +243,24 @@ class StateTestCase(unittest.TestCase):
state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)
- tup = ("pdu_id", "origin.com", 5)
- pdus = [tup]
+ snapshot = Mock()
+ snapshot.prev_state_pdu = state_pdu
+ event_id = "pdu_id@origin.com"
- self.persistence.get_latest_pdus_in_context.return_value = pdus
- self.persistence.get_current_state_pdu.return_value = state_pdu
+ def fill_out_prev_events(event):
+ event.prev_events = [event_id]
+ event.depth = 6
+ snapshot.fill_out_prev_events = fill_out_prev_events
- yield self.state.handle_new_event(event)
+ yield self.state.handle_new_event(event, snapshot)
- self.assertLess(tup[2], event.depth)
+ self.assertLess(5, event.depth)
self.assertEquals(1, len(event.prev_events))
prev_id = event.prev_events[0]
- self.assertEqual(encode_event_id(tup[0], tup[1]), prev_id)
+ self.assertEqual(event_id, prev_id)
self.assertEqual(
encode_event_id(state_pdu.pdu_id, state_pdu.origin),
diff --git a/tests/utils.py b/tests/utils.py
index f40cbce51d..6666b06931 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -127,6 +127,15 @@ class MemoryDataStore(object):
self.current_state = {}
self.events = []
+ class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
+ def fill_out_prev_events(self, event):
+ pass
+
+ def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+ return self.Snapshot(
+ room_id, user_id, self.get_room_member(user_id, room_id)
+ )
+
def register(self, user_id, token, password_hash):
if user_id in self.tokens_to_users.values():
raise StoreError(400, "User in use.")
|