summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py19
-rw-r--r--synapse/federation/replication.py7
-rw-r--r--synapse/handlers/federation.py12
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/storage/__init__.py14
-rw-r--r--synapse/storage/state.py9
-rw-r--r--tests/handlers/test_room.py22
7 files changed, 54 insertions, 39 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 6c2d3db26e..87f19a96d6 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -36,6 +36,7 @@ class Auth(object):
     def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
 
     def check(self, event, raises=False):
         """ Checks if this event is correctly authed.
@@ -90,7 +91,7 @@ class Auth(object):
             )
             logger.info("Denying! %s", event)
             if raises:
-                raise e
+                raise
 
         return False
 
@@ -109,9 +110,21 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
-        joined_hosts = yield self.store.get_joined_hosts_for_room(room_id)
+        curr_state = yield self.state.get_current_state(room_id)
+
+        for event in curr_state:
+            if event.type == RoomMemberEvent.TYPE:
+                try:
+                    if self.hs.parse_userid(event.state_key).domain != host:
+                        continue
+                except:
+                    logger.warn("state_key not user_id: %s", event.state_key)
+                    continue
+
+                if event.content["membership"] == Membership.JOIN:
+                    defer.returnValue(True)
 
-        defer.returnValue(host in joined_hosts)
+        defer.returnValue(False)
 
     def check_event_sender_in_room(self, event):
         key = (RoomMemberEvent.TYPE, event.user_id, )
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 5c625ddabf..beec17e386 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -267,8 +267,6 @@ class ReplicationLayer(object):
         transaction = Transaction(**transaction_data)
 
         pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
-        for pdu in pdus:
-            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdus)
 
@@ -452,15 +450,12 @@ class ReplicationLayer(object):
         )
 
         logger.debug("Got content: %s", content)
+
         state = [Pdu(outlier=True, **p) for p in content.get("state", [])]
-        for pdu in state:
-            yield self._handle_new_pdu(destination, pdu)
 
         auth_chain = [
             Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
         ]
-        for pdu in auth_chain:
-            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(state)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d8d5730b65..99655c8bb0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -229,12 +229,6 @@ class FederationHandler(BaseHandler):
     @log_function
     @defer.inlineCallbacks
     def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
-        hosts = yield self.store.get_joined_hosts_for_room(room_id)
-        if self.hs.hostname in hosts:
-            # We are already in the room.
-            logger.debug("We're already in the room apparently")
-            defer.returnValue(False)
-
         pdu = yield self.replication_layer.make_join(
             target_host,
             room_id,
@@ -268,7 +262,7 @@ class FederationHandler(BaseHandler):
 
             logger.debug("do_invite_join state: %s", state)
 
-            is_new_state = yield self.state_handler.annotate_event_with_state(
+            yield self.state_handler.annotate_event_with_state(
                 event,
                 old_state=state
             )
@@ -296,13 +290,13 @@ class FederationHandler(BaseHandler):
                 yield self.store.persist_event(
                     e,
                     backfilled=False,
-                    is_new_state=False
+                    is_new_state=True
                 )
 
             yield self.store.persist_event(
                 event,
                 backfilled=False,
-                is_new_state=is_new_state
+                is_new_state=True
             )
         finally:
             room_queue = self.room_queues[room_id]
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3642fcfc6d..825957f721 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -24,6 +24,7 @@ from synapse.api.events.room import (
     RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
 )
 from synapse.util import stringutils
+from synapse.util.async import run_on_reactor
 from ._base import BaseHandler
 
 import logging
@@ -432,9 +433,12 @@ class RoomMemberHandler(BaseHandler):
         # that we are allowed to join when we decide whether or not we
         # need to do the invite/join dance.
 
-        hosts = yield self.store.get_joined_hosts_for_room(room_id)
+        is_host_in_room = yield self.auth.check_host_in_room(
+            event.room_id,
+            self.hs.hostname
+        )
 
-        if self.hs.hostname in hosts:
+        if is_host_in_room:
             should_do_dance = False
         elif room_host:
             should_do_dance = True
@@ -517,6 +521,8 @@ class RoomMemberHandler(BaseHandler):
     @defer.inlineCallbacks
     def _do_local_membership_update(self, event, membership, snapshot,
                                     do_auth):
+        yield run_on_reactor()
+
         # If we're inviting someone, then we should also send it to that
         # HS.
         target_user_id = event.state_key
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4034437f6b..72290eb5a0 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -186,6 +186,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 "events",
                 vals,
                 or_replace=(not outlier),
+                or_ignore=bool(outlier),
             )
         except:
             logger.warn(
@@ -217,7 +218,12 @@ class DataStore(RoomMemberStore, RoomStore,
             if hasattr(event, "replaces_state"):
                 vals["prev_state"] = event.replaces_state
 
-            self._simple_insert_txn(txn, "state_events", vals)
+            self._simple_insert_txn(
+                txn,
+                "state_events",
+                vals,
+                or_replace=True,
+            )
 
             self._simple_insert_txn(
                 txn,
@@ -227,7 +233,8 @@ class DataStore(RoomMemberStore, RoomStore,
                     "room_id": event.room_id,
                     "type": event.type,
                     "state_key": event.state_key,
-                }
+                },
+                or_replace=True,
             )
 
             for e_id, h in event.prev_state:
@@ -252,7 +259,8 @@ class DataStore(RoomMemberStore, RoomStore,
                         "room_id": event.room_id,
                         "type": event.type,
                         "state_key": event.state_key,
-                    }
+                    },
+                    or_replace=True,
                 )
 
                 for prev_state_id, _ in event.prev_state:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 68975969f5..2f3a70b4e5 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -70,7 +70,8 @@ class StateStore(SQLBaseStore):
                 values={
                     "room_id": event.room_id,
                     "event_id": event.event_id,
-                }
+                },
+                or_ignore=True,
             )
 
             for state in event.state_events.values():
@@ -83,7 +84,8 @@ class StateStore(SQLBaseStore):
                         "type": state.type,
                         "state_key": state.state_key,
                         "event_id": state.event_id,
-                    }
+                    },
+                    or_ignore=True,
                 )
 
         self._simple_insert_txn(
@@ -92,5 +94,6 @@ class StateStore(SQLBaseStore):
             values={
                 "state_group": state_group,
                 "event_id": event.event_id,
-            }
+            },
+            or_replace=True,
         )
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index ee264e5ee2..cbe591ab90 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -44,7 +44,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             ]),
             datastore=NonCallableMock(spec_set=[
                 "persist_event",
-                "get_joined_hosts_for_room",
                 "get_room_member",
                 "get_room",
                 "store_room",
@@ -58,9 +57,14 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
                 "profile_handler",
                 "federation_handler",
             ]),
-            auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
+            auth=NonCallableMock(spec_set=[
+                "check",
+                "add_auth_events",
+                "check_host_in_room",
+            ]),
             state_handler=NonCallableMock(spec_set=[
                 "annotate_event_with_state",
+                "get_current_state",
             ]),
             config=self.mock_config,
         )
@@ -76,6 +80,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         self.notifier = hs.get_notifier()
         self.state_handler = hs.get_state_handler()
         self.distributor = hs.get_distributor()
+        self.auth = hs.get_auth()
         self.hs = hs
 
         self.handlers.federation_handler = self.federation
@@ -108,11 +113,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             content=content,
         )
 
-        joined = ["red", "green"]
-
-        self.datastore.get_joined_hosts_for_room.return_value = (
-            defer.succeed(joined)
-        )
+        self.auth.check_host_in_room.return_value = defer.succeed(True)
 
         store_id = "store_id_fooo"
         self.datastore.persist_event.return_value = defer.succeed(store_id)
@@ -164,12 +165,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             room_id=room_id,
         )
 
-        joined = ["red", "green"]
-
-        def get_joined(*args):
-            return defer.succeed(joined)
-
-        self.datastore.get_joined_hosts_for_room.side_effect = get_joined
+        self.auth.check_host_in_room.return_value = defer.succeed(True)
 
         store_id = "store_id_fooo"
         self.datastore.persist_event.return_value = defer.succeed(store_id)