summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/federation_client.py15
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/handlers/message.py33
-rw-r--r--synapse/storage/_base.py26
-rw-r--r--synapse/storage/events.py125
-rw-r--r--synapse/storage/stream.py2
-rw-r--r--tests/storage/test_base.py3
7 files changed, 122 insertions, 84 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 904c7c0945..c255df1bbb 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -222,7 +222,7 @@ class FederationClient(FederationBase):
                         for p in transaction_data["pdus"]
                     ]
 
-                    if pdu_list:
+                    if pdu_list and pdu_list[0]:
                         pdu = pdu_list[0]
 
                         # Check signatures are correct.
@@ -255,7 +255,7 @@ class FederationClient(FederationBase):
                 )
                 continue
 
-        if self._get_pdu_cache is not None:
+        if self._get_pdu_cache is not None and pdu:
             self._get_pdu_cache[event_id] = pdu
 
         defer.returnValue(pdu)
@@ -475,6 +475,9 @@ class FederationClient(FederationBase):
             limit (int): Maximum number of events to return.
             min_depth (int): Minimum depth of events to return.
         """
+        logger.debug("get_missing_events: latest_events: %r", latest_events)
+        logger.debug("get_missing_events: earliest_events_ids: %r", earliest_events_ids)
+
         try:
             content = yield self.transport_layer.get_missing_events(
                 destination=destination,
@@ -485,6 +488,8 @@ class FederationClient(FederationBase):
                 min_depth=min_depth,
             )
 
+            logger.debug("get_missing_events: Got content: %r", content)
+
             events = [
                 self.event_from_pdu_json(e)
                 for e in content.get("events", [])
@@ -494,6 +499,8 @@ class FederationClient(FederationBase):
                 destination, events, outlier=False
             )
 
+            logger.debug("get_missing_events: signed_events: %r", signed_events)
+
             have_gotten_all_from_destination = True
         except HttpResponseException as e:
             if not e.code == 400:
@@ -518,6 +525,8 @@ class FederationClient(FederationBase):
             # Are we missing any?
 
             seen_events = set(earliest_events_ids)
+
+            logger.debug("get_missing_events: signed_events2: %r", signed_events)
             seen_events.update(e.event_id for e in signed_events)
 
             missing_events = {}
@@ -561,7 +570,7 @@ class FederationClient(FederationBase):
 
             res = yield defer.DeferredList(deferreds, consumeErrors=True)
             for (result, val), (e_id, _) in zip(res, ordered_missing):
-                if result:
+                if result and val:
                     signed_events.append(val)
                 else:
                     failed_to_fetch.add(e_id)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index cd79e23f4b..2c6488dd1b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -415,6 +415,8 @@ class FederationServer(FederationBase):
                 pdu.internal_metadata.outlier = True
             elif min_depth and pdu.depth > min_depth:
                 if get_missing and prevs - seen:
+                    logger.debug("We're missing: %r", prevs-seen)
+
                     latest = yield self.store.get_latest_event_ids_in_room(
                         pdu.room_id
                     )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 867fdbefb0..6a1b25d112 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -303,18 +303,27 @@ class MessageHandler(BaseHandler):
             if event.membership != Membership.JOIN:
                 return
             try:
-                (messages, token), current_state = yield defer.gatherResults(
-                    [
-                        self.store.get_recent_events_for_room(
-                            event.room_id,
-                            limit=limit,
-                            end_token=now_token.room_key,
-                        ),
-                        self.state_handler.get_current_state(
-                            event.room_id
-                        ),
-                    ]
-                ).addErrback(unwrapFirstError)
+                # (messages, token), current_state = yield defer.gatherResults(
+                #     [
+                #         self.store.get_recent_events_for_room(
+                #             event.room_id,
+                #             limit=limit,
+                #             end_token=now_token.room_key,
+                #         ),
+                #         self.state_handler.get_current_state(
+                #             event.room_id
+                #         ),
+                #     ]
+                # ).addErrback(unwrapFirstError)
+
+                messages, token = yield self.store.get_recent_events_for_room(
+                    event.room_id,
+                    limit=limit,
+                    end_token=now_token.room_key,
+                )
+                current_state = yield self.state_handler.get_current_state(
+                    event.room_id
+                )
 
                 start_token = now_token.copy_and_replace("room_key", token[0])
                 end_token = now_token.copy_and_replace("room_key", token[1])
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ceff99c16d..0df1b46edc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -301,10 +301,12 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
-        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_lock = threading.Lock()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+        self._pending_ds = []
+
         self.database_engine = hs.database_engine
 
         self._stream_id_gen = StreamIdGenerator()
@@ -344,8 +346,7 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    @contextlib.contextmanager
-    def _new_transaction(self, conn, desc, after_callbacks):
+    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
         start = time.time() * 1000
         txn_id = self._TXN_ID
 
@@ -366,6 +367,9 @@ class SQLBaseStore(object):
                     txn = LoggingTransaction(
                         txn, name, self.database_engine, after_callbacks
                     )
+                    r = func(txn, *args, **kwargs)
+                    conn.commit()
+                    return r
                 except self.database_engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
@@ -398,17 +402,6 @@ class SQLBaseStore(object):
                                 )
                             continue
                     raise
-
-                try:
-                    yield txn
-                    conn.commit()
-                    return
-                except:
-                    try:
-                        conn.rollback()
-                    except:
-                        pass
-                    raise
         except Exception as e:
             logger.debug("[TXN FAIL] {%s} %s", name, e)
             raise
@@ -440,8 +433,9 @@ class SQLBaseStore(object):
                     conn.reconnect()
 
                 current_context.copy_to(context)
-                with self._new_transaction(conn, desc, after_callbacks) as txn:
-                    return func(txn, *args, **kwargs)
+                return self._new_transaction(
+                    conn, desc, after_callbacks, func, *args, **kwargs
+                )
 
         result = yield preserve_context_over_fn(
             self._db_pool.runWithConnection,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b4abd83260..260bdf0ec4 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -420,12 +420,14 @@ class EventsStore(SQLBaseStore):
             ])
 
         if not txn:
+            logger.debug("enqueue before")
             missing_events = yield self._enqueue_events(
                 missing_events_ids,
                 check_redacted=check_redacted,
                 get_prev_content=get_prev_content,
                 allow_rejected=allow_rejected,
             )
+            logger.debug("enqueue after")
         else:
             missing_events = self._fetch_events_txn(
                 txn,
@@ -498,41 +500,39 @@ class EventsStore(SQLBaseStore):
             allow_rejected=allow_rejected,
         ))
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True,
-                        get_prev_content=False, allow_rejected=False):
-        if not events:
-            defer.returnValue({})
-
-        def do_fetch(conn):
-            event_list = []
+    def _do_fetch(self, conn):
+        event_list = []
+        try:
             while True:
-                try:
-                    with self._event_fetch_lock:
-                        i = 0
-                        while not self._event_fetch_list:
-                            self._event_fetch_ongoing -= 1
-                            return
-
-                        event_list = self._event_fetch_list
-                        self._event_fetch_list = []
-
-                    event_id_lists = zip(*event_list)[0]
-                    event_ids = [
-                        item for sublist in event_id_lists for item in sublist
-                    ]
-
-                    with self._new_transaction(conn, "do_fetch", []) as txn:
-                        rows = self._fetch_event_rows(txn, event_ids)
-
-                    row_dict = {
-                        r["event_id"]: r
-                        for r in rows
-                    }
+                logger.debug("do_fetch getting lock")
+                with self._event_fetch_lock:
+                    logger.debug("do_fetch go lock: %r", self._event_fetch_list)
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+                    if not event_list:
+                        self._event_fetch_ongoing -= 1
+                        return
+
+                event_id_lists = zip(*event_list)[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], self._fetch_event_rows, event_ids
+                )
 
-                    for ids, d in event_list:
-                        def fire():
-                            if not d.called:
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                logger.debug("do_fetch got events: %r", row_dict.keys())
+
+                def fire(evs):
+                    for ids, d in evs:
+                        if not d.called:
+                            try:
                                 d.callback(
                                     [
                                         row_dict[i]
@@ -540,32 +540,51 @@ class EventsStore(SQLBaseStore):
                                         if i in row_dict
                                     ]
                                 )
-                        reactor.callFromThread(fire)
-                except Exception as e:
-                    logger.exception("do_fetch")
-                    for _, d in event_list:
-                        if not d.called:
-                            reactor.callFromThread(d.errback, e)
+                            except:
+                                logger.exception("Failed to callback")
+                reactor.callFromThread(fire, event_list)
+        except Exception as e:
+            logger.exception("do_fetch")
 
-                    with self._event_fetch_lock:
-                        self._event_fetch_ongoing -= 1
-                        return
+            def fire(evs):
+                for _, d in evs:
+                    if not d.called:
+                        d.errback(e)
+
+            if event_list:
+                reactor.callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True,
+                        get_prev_content=False, allow_rejected=False):
+        if not events:
+            defer.returnValue({})
 
         events_d = defer.Deferred()
-        with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, events_d)
-            )
+        try:
+            logger.debug("enqueueueueue getting lock")
+            with self._event_fetch_lock:
+                logger.debug("enqueue go lock")
+                self._event_fetch_list.append(
+                    (events, events_d)
+                )
 
-            self._event_fetch_lock.notify_all()
+                self._event_fetch_ongoing += 1
 
-            # if self._event_fetch_ongoing < 5:
-            self._event_fetch_ongoing += 1
             self.runWithConnection(
-                do_fetch
+                self._do_fetch
             )
 
-        rows = yield events_d
+        except Exception as e:
+            if not events_d.called:
+                events_d.errback(e)
+
+        logger.debug("events_d before")
+        try:
+            rows = yield events_d
+        except:
+            logger.exception("events_d")
+        logger.debug("events_d after")
 
         res = yield defer.gatherResults(
             [
@@ -580,6 +599,7 @@ class EventsStore(SQLBaseStore):
             ],
             consumeErrors=True
         )
+        logger.debug("gatherResults after")
 
         defer.returnValue({
             e.event_id: e
@@ -639,7 +659,8 @@ class EventsStore(SQLBaseStore):
                     rejected_reason=row["rejects"],
                 )
                 for row in rows
-            ]
+            ],
+            consumeErrors=True,
         )
 
         defer.returnValue({
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d16b57c515..af45fc5619 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -357,10 +357,12 @@ class StreamStore(SQLBaseStore):
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+        logger.debug("stream before")
         events = yield self._get_events(
             [r["event_id"] for r in rows],
             get_prev_content=True
         )
+        logger.debug("stream after")
 
         self._set_before_and_after(events, rows)
 
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 8c348ecc95..8573f18b55 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -33,8 +33,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def setUp(self):
         self.db_pool = Mock(spec=["runInteraction"])
         self.mock_txn = Mock()
-        self.mock_conn = Mock(spec_set=["cursor"])
+        self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
         self.mock_conn.cursor.return_value = self.mock_txn
+        self.mock_conn.rollback.return_value = None
         # Our fake runInteraction just runs synchronously inline
 
         def runInteraction(func, *args, **kwargs):