summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/persistence.py30
-rw-r--r--synapse/storage/__init__.py32
-rw-r--r--synapse/storage/pdu.py22
-rw-r--r--tests/federation/test_federation.py2
4 files changed, 36 insertions, 50 deletions
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index e0e4de4e8c..76d37a0c52 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -25,7 +25,6 @@ from .units import Pdu
 
 from synapse.util.logutils import log_function
 
-import copy
 import json
 import logging
 
@@ -48,9 +47,8 @@ class PduActions(object):
         Returns:
             Deferred
         """
-        return self._persist(pdu)
+        return self.store.persist_event(pdu=pdu)
 
-    @defer.inlineCallbacks
     @log_function
     def persist_outgoing(self, pdu):
         """ Persists the given `Pdu` that this home server created.
@@ -58,9 +56,7 @@ class PduActions(object):
         Returns:
             Deferred
         """
-        ret = yield self._persist(pdu)
-
-        defer.returnValue(ret)
+        return self.store.persist_event(pdu=pdu)
 
     @log_function
     def mark_as_processed(self, pdu):
@@ -143,28 +139,6 @@ class PduActions(object):
             depth=pdu.depth
         )
 
-    @defer.inlineCallbacks
-    @log_function
-    def _persist(self, pdu):
-        kwargs = copy.copy(pdu.__dict__)
-        unrec_keys = copy.copy(pdu.unrecognized_keys)
-        del kwargs["content"]
-        kwargs["content_json"] = json.dumps(pdu.content)
-        kwargs["unrecognized_keys"] = json.dumps(unrec_keys)
-
-        logger.debug("Persisting: %s", repr(kwargs))
-
-        if pdu.is_state:
-            ret = yield self.store.persist_state(**kwargs)
-        else:
-            ret = yield self.store.persist_pdu(**kwargs)
-
-        yield self.store.update_min_depth_for_context(
-            pdu.context, pdu.depth
-        )
-
-        defer.returnValue(ret)
-
 
 class TransactionActions(object):
     """ Defines persistence actions that relate to handling Transactions.
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 5e52e9fecf..a726b7346b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -57,7 +57,7 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, backfilled=False):
+    def persist_event(self, event=None, backfilled=False, pdu=None):
         # FIXME (erikj): This should be removed when we start amalgamating
         # event and pdu storage
         yield self.hs.get_federation().fill_out_prev_events(event)
@@ -70,7 +70,11 @@ class DataStore(RoomMemberStore, RoomStore,
             stream_ordering = self.min_token
 
         latest = yield self._db_pool.runInteraction(
-            self._persist_event_txn, event, backfilled, stream_ordering
+            self._persist_pdu_event_txn,
+            pdu=pdu,
+            event=event,
+            backfilled=backfilled,
+            stream_ordering=stream_ordering,
         )
         defer.returnValue(latest)
 
@@ -92,6 +96,30 @@ class DataStore(RoomMemberStore, RoomStore,
         event = self._parse_event_from_row(events_dict)
         defer.returnValue(event)
 
+    def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
+                               backfilled=False, stream_ordering=None):
+        if pdu is not None:
+            self._persist_pdu_txn(txn, pdu)
+        if event is not None:
+            self._persist_event_txn(txn, event, backfilled, stream_ordering)
+
+    def _persist_pdu_txn(self, txn, pdu):
+        cols = dict(pdu.__dict__)
+        unrec_keys = dict(pdu.unrecognized_keys)
+        del cols["content"]
+        del cols["prev_pdus"]
+        cols["content_json"] = json.dumps(pdu.content)
+        cols["unrecognized_keys"] = json.dumps(unrec_keys)
+
+        logger.debug("Persisting: %s", repr(cols))
+
+        if pdu.is_state:
+            self._persist_state_txn(txn, pdu.prev_pdus, cols)
+        else:
+            self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
+
+        self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
+
     @log_function
     def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None):
         if event.type == RoomMemberEvent.TYPE:
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 7655f43ede..657295b1d7 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -114,7 +114,7 @@ class PduStore(SQLBaseStore):
 
         return self._get_pdu_tuples(txn, res)
 
-    def persist_pdu(self, prev_pdus, **cols):
+    def _persist_pdu_txn(self, txn, prev_pdus, cols):
         """Inserts a (non-state) PDU into the database.
 
         Args:
@@ -122,11 +122,6 @@ class PduStore(SQLBaseStore):
             prev_pdus (list)
             **cols: The columns to insert into the PdusTable.
         """
-        return self._db_pool.runInteraction(
-            self._persist_pdu, prev_pdus, cols
-        )
-
-    def _persist_pdu(self, txn, prev_pdus, cols):
         entry = PdusTable.EntryType(
             **{k: cols.get(k, None) for k in PdusTable.fields}
         )
@@ -262,7 +257,7 @@ class PduStore(SQLBaseStore):
 
         return row[0] if row else None
 
-    def update_min_depth_for_context(self, context, depth):
+    def _update_min_depth_for_context_txn(self, txn, context, depth):
         """Update the minimum `depth` of the given context, which is the line
         on which we stop backfilling backwards.
 
@@ -270,11 +265,6 @@ class PduStore(SQLBaseStore):
             context (str)
             depth (int)
         """
-        return self._db_pool.runInteraction(
-            self._update_min_depth_for_context, context, depth
-        )
-
-    def _update_min_depth_for_context(self, txn, context, depth):
         min_depth = self._get_min_depth_interaction(txn, context)
 
         do_insert = depth < min_depth if min_depth else True
@@ -485,7 +475,7 @@ class StatePduStore(SQLBaseStore):
     """A collection of queries for handling state PDUs.
     """
 
-    def persist_state(self, prev_pdus, **cols):
+    def _persist_state_txn(self, txn, prev_pdus, cols):
         """Inserts a state PDU into the database
 
         Args:
@@ -493,12 +483,6 @@ class StatePduStore(SQLBaseStore):
             prev_pdus (list)
             **cols: The columns to insert into the PdusTable and StatePdusTable
         """
-
-        return self._db_pool.runInteraction(
-            self._persist_state, prev_pdus, cols
-        )
-
-    def _persist_state(self, txn, prev_pdus, cols):
         pdu_entry = PdusTable.EntryType(
             **{k: cols.get(k, None) for k in PdusTable.fields}
         )
diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py
index 58590e4fcd..938b57bec9 100644
--- a/tests/federation/test_federation.py
+++ b/tests/federation/test_federation.py
@@ -58,7 +58,7 @@ class FederationTestCase(unittest.TestCase):
         self.mock_persistence = Mock(spec=[
             "get_current_state_for_context",
             "get_pdu",
-            "persist_pdu",
+            "persist_event",
             "update_min_depth_for_context",
             "prep_send_transaction",
             "delivered_txn",