summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/_base.py20
-rw-r--r--synapse/handlers/room.py65
2 files changed, 33 insertions, 52 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 3f07b5aa4a..32c0d6b8aa 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+from twisted.internet import defer
 
 class BaseHandler(object):
 
@@ -26,3 +26,21 @@ class BaseHandler(object):
         self.state_handler = hs.get_state_handler()
         self.distributor = hs.get_distributor()
         self.hs = hs
+
+
+class BaseRoomHandler(BaseHandler):
+
+    @defer.inlineCallbacks
+    def _on_new_room_event(self, event, snapshot, extra_destinations=[]):
+        store_id = yield self.store.persist_event(event)
+
+        destinations = set(extra_destinations)
+        # Send a PDU to all hosts who have joined the room.
+        destinations.update((yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )))
+        event.destinations = list(destinations)
+
+        self.notifier.on_new_room_event(event, store_id)
+
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a94cfaec2a..4797f8be0c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,14 +25,14 @@ from synapse.api.events.room import (
 from synapse.api.streams.event import EventStream, EventsStreamData
 from synapse.handlers.presence import PresenceStreamData
 from synapse.util import stringutils
-from ._base import BaseHandler
+from ._base import BaseRoomHandler
 
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class MessageHandler(BaseHandler):
+class MessageHandler(BaseRoomHandler):
 
     def __init__(self, hs):
         super(MessageHandler, self).__init__(hs)
@@ -89,16 +89,7 @@ class MessageHandler(BaseHandler):
         if not suppress_auth:
             yield self.auth.check(event, snapshot, raises=True)
 
-        # store message in db
-        store_id = yield self.store.persist_event(event)
-
-        event.destinations = yield self.store.get_joined_hosts_for_room(
-            event.room_id
-        )
-
-        self.notifier.on_new_room_event(event, store_id)
-
-        yield self.hs.get_federation().handle_new_event(event, snapshot)
+        yield self._on_new_room_event(event, snapshot)
 
     @defer.inlineCallbacks
     def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@@ -144,15 +135,7 @@ class MessageHandler(BaseHandler):
 
         yield self.state_handler.handle_new_event(event)
 
-        # store in db
-        store_id = yield self.store.persist_event(event)
-
-        event.destinations = yield self.store.get_joined_hosts_for_room(
-            event.room_id
-        )
-        self.notifier.on_new_room_event(event, store_id)
-
-        yield self.hs.get_federation().handle_new_event(event, snapshot)
+        yield self._on_new_room_event(event, snapshot)
 
     @defer.inlineCallbacks
     def get_room_data(self, user_id=None, room_id=None,
@@ -226,14 +209,7 @@ class MessageHandler(BaseHandler):
         yield self.auth.check(event, snapshot, raises=True)
 
         # store message in db
-        store_id = yield self.store.persist_event(event)
-
-        event.destinations = yield self.store.get_joined_hosts_for_room(
-            event.room_id
-        )
-        yield self.hs.get_federation().handle_new_event(event, snapshot)
-
-        self.notifier.on_new_room_event(event, store_id)
+        yield self._on_new_room_event(event, snapshot)
 
     @defer.inlineCallbacks
     def snapshot_all_rooms(self, user_id=None, pagin_config=None,
@@ -311,7 +287,7 @@ class MessageHandler(BaseHandler):
         defer.returnValue(ret)
 
 
-class RoomCreationHandler(BaseHandler):
+class RoomCreationHandler(BaseRoomHandler):
 
     @defer.inlineCallbacks
     def create_room(self, user_id, room_id, config):
@@ -417,7 +393,7 @@ class RoomCreationHandler(BaseHandler):
         defer.returnValue(result)
 
 
-class RoomMemberHandler(BaseHandler):
+class RoomMemberHandler(BaseRoomHandler):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
     #   API that takes ID strings and returns pagination chunks. These concerns
@@ -707,39 +683,26 @@ class RoomMemberHandler(BaseHandler):
 
         defer.returnValue([r.room_id for r in rooms])
 
-    @defer.inlineCallbacks
     def _do_local_membership_update(self, event, membership, snapshot):
-        # store membership
-        store_id = yield self.store.persist_event(event)
-
-        # Send a PDU to all hosts who have joined the room.
-        destinations = yield self.store.get_joined_hosts_for_room(
-            event.room_id
-        )
+        destinations = []
 
         # If we're inviting someone, then we should also send it to that
         # HS.
         target_user_id = event.state_key
         if membership == Membership.INVITE:
-            host = UserID.from_string(
-                target_user_id, self.hs
-            ).domain
+            host = UserID.from_string(target_user_id, self.hs).domain
             destinations.append(host)
 
         # If we are joining a remote HS, include that.
         if membership == Membership.JOIN:
-            host = UserID.from_string(
-                target_user_id, self.hs
-            ).domain
+            host = UserID.from_string(target_user_id, self.hs).domain
             destinations.append(host)
 
-        event.destinations = list(set(destinations))
-
-        yield self.hs.get_federation().handle_new_event(event, snapshot)
-        self.notifier.on_new_room_event(event, store_id)
-
+        return self._on_new_room_event(
+            event, snapshot, extra_destinations=destinations
+        )
 
-class RoomListHandler(BaseHandler):
+class RoomListHandler(BaseRoomHandler):
 
     @defer.inlineCallbacks
     def get_public_room_list(self):