summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py170
-rw-r--r--synapse/storage/_base.py12
-rw-r--r--synapse/storage/feedback.py4
-rw-r--r--synapse/storage/pdu.py29
-rw-r--r--synapse/storage/room.py10
-rw-r--r--synapse/storage/roommember.py28
-rw-r--r--synapse/storage/stream.py13
7 files changed, 190 insertions, 76 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 38ab03c45c..e8faba3eeb 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -57,20 +57,22 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, backfilled=False):
-        if event.type == RoomMemberEvent.TYPE:
-            yield self._store_room_member(event)
-        elif event.type == FeedbackEvent.TYPE:
-            yield self._store_feedback(event)
-#        elif event.type == RoomConfigEvent.TYPE:
-#            yield self._store_room_config(event)
-        elif event.type == RoomNameEvent.TYPE:
-            yield self._store_room_name(event)
-        elif event.type == RoomTopicEvent.TYPE:
-            yield self._store_room_topic(event)
-
-        ret = yield self._store_event(event, backfilled)
-        defer.returnValue(ret)
+    def persist_event(self, event=None, backfilled=False, pdu=None):
+        stream_ordering = None
+        if backfilled:
+            if not self.min_token_deferred.called:
+                yield self.min_token_deferred
+            self.min_token -= 1
+            stream_ordering = self.min_token
+
+        latest = yield self._db_pool.runInteraction(
+            self._persist_pdu_event_txn,
+            pdu=pdu,
+            event=event,
+            backfilled=backfilled,
+            stream_ordering=stream_ordering,
+        )
+        defer.returnValue(latest)
 
     @defer.inlineCallbacks
     def get_event(self, event_id):
@@ -89,12 +91,44 @@ class DataStore(RoomMemberStore, RoomStore,
         event = self._parse_event_from_row(events_dict)
         defer.returnValue(event)
 
-    @defer.inlineCallbacks
+    def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
+                               backfilled=False, stream_ordering=None):
+        if pdu is not None:
+            self._persist_event_pdu_txn(txn, pdu)
+        if event is not None:
+            return self._persist_event_txn(
+                txn, event, backfilled, stream_ordering
+            )
+
+    def _persist_event_pdu_txn(self, txn, pdu):
+        cols = dict(pdu.__dict__)
+        unrec_keys = dict(pdu.unrecognized_keys)
+        del cols["content"]
+        del cols["prev_pdus"]
+        cols["content_json"] = json.dumps(pdu.content)
+        cols["unrecognized_keys"] = json.dumps(unrec_keys)
+
+        logger.debug("Persisting: %s", repr(cols))
+
+        if pdu.is_state:
+            self._persist_state_txn(txn, pdu.prev_pdus, cols)
+        else:
+            self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
+
+        self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
+
     @log_function
-    def _store_event(self, event, backfilled):
-        # FIXME (erikj): This should be removed when we start amalgamating
-        # event and pdu storage
-        yield self.hs.get_federation().fill_out_prev_events(event)
+    def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None):
+        if event.type == RoomMemberEvent.TYPE:
+            self._store_room_member_txn(txn, event)
+        elif event.type == FeedbackEvent.TYPE:
+            self._store_feedback_txn(txn,event)
+#        elif event.type == RoomConfigEvent.TYPE:
+#            self._store_room_config_txn(txn, event)
+        elif event.type == RoomNameEvent.TYPE:
+            self._store_room_name_txn(txn, event)
+        elif event.type == RoomTopicEvent.TYPE:
+            self._store_room_topic_txn(txn, event)
 
         vals = {
             "topological_ordering": event.depth,
@@ -105,17 +139,14 @@ class DataStore(RoomMemberStore, RoomStore,
             "processed": True,
         }
 
+        if stream_ordering is not None:
+            vals["stream_ordering"] = stream_ordering
+
         if hasattr(event, "outlier"):
             vals["outlier"] = event.outlier
         else:
             vals["outlier"] = False
 
-        if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            self.min_token -= 1
-            vals["stream_ordering"] = self.min_token
-
         unrec = {
             k: v
             for k, v in event.get_full_dict().items()
@@ -124,7 +155,7 @@ class DataStore(RoomMemberStore, RoomStore,
         vals["unrecognized_keys"] = json.dumps(unrec)
 
         try:
-            yield self._simple_insert("events", vals)
+            self._simple_insert_txn(txn, "events", vals)
         except:
             logger.exception(
                 "Failed to persist, probably duplicate: %s",
@@ -143,9 +174,10 @@ class DataStore(RoomMemberStore, RoomStore,
             if hasattr(event, "prev_state"):
                 vals["prev_state"] = event.prev_state
 
-            yield self._simple_insert("state_events", vals)
+            self._simple_insert_txn(txn, "state_events", vals)
 
-            yield self._simple_insert(
+            self._simple_insert_txn(
+                txn,
                 "current_state_events",
                 {
                     "event_id": event.event_id,
@@ -155,8 +187,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 }
             )
 
-        latest = yield self.get_room_events_max_id()
-        defer.returnValue(latest)
+        return self._get_room_events_max_id_txn(txn)
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -192,6 +223,85 @@ class DataStore(RoomMemberStore, RoomStore,
         defer.returnValue(self.min_token)
 
 
+    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+        """Snapshot the room for an update by a user
+        Args:
+            room_id (synapse.types.RoomId): The room to snapshot.
+            user_id (synapse.types.UserId): The user to snapshot the room for.
+            state_type (str): Optional state type to snapshot.
+            state_key (str): Optional state key to snapshot.
+        Returns:
+            synapse.storage.Snapshot: A snapshot of the state of the room.
+        """
+        def _snapshot(txn):
+            membership_state = self._get_room_member(txn, user_id, room_id)
+            prev_pdus = self._get_latest_pdus_in_context(
+                txn, room_id
+            )
+            if state_type is not None and state_key is not None:
+                prev_state_pdu = self._get_current_state_pdu(
+                    txn, room_id, state_type, state_key
+                )
+            else:
+                prev_state_pdu = None
+
+            return Snapshot(
+                store=self,
+                room_id=room_id,
+                user_id=user_id,
+                prev_pdus=prev_pdus,
+                membership_state=membership_state,
+                state_type=state_type,
+                state_key=state_key,
+                prev_state_pdu=prev_state_pdu,
+            )
+
+        return self._db_pool.runInteraction(_snapshot)
+
+
+class Snapshot(object):
+    """Snapshot of the state of a room
+    Args:
+        store (DataStore): The datastore.
+        room_id (RoomId): The room of the snapshot.
+        user_id (UserId): The user this snapshot is for.
+        prev_pdus (list): The list of PDU ids this snapshot is after.
+        membership_state (RoomMemberEvent): The current state of the user in
+            the room.
+        state_type (str, optional): State type captured by the snapshot
+        state_key (str, optional): State key captured by the snapshot
+        prev_state_pdu (PduEntry, optional): pdu id of
+            the previous value of the state type and key in the room.
+    """
+
+    def __init__(self, store, room_id, user_id, prev_pdus,
+                 membership_state, state_type=None, state_key=None,
+                 prev_state_pdu=None):
+        self.store = store
+        self.room_id = room_id
+        self.user_id = user_id
+        self.prev_pdus = prev_pdus
+        self.membership_state = membership_state
+        self.state_type = state_type
+        self.state_key = state_key
+        self.prev_state_pdu = prev_state_pdu
+
+    def fill_out_prev_events(self, event):
+        if hasattr(event, "prev_events"):
+            return
+
+        es = [
+            "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
+        ]
+
+        event.prev_events = [e for e in es if e != event.event_id]
+
+        if self.prev_pdus:
+            event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
+        else:
+            event.depth = 0
+
+
 def schema_path(schema):
     """ Get a filesystem path for the named database schema
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 75aab2d3b9..33d56f47ce 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -86,16 +86,18 @@ class SQLBaseStore(object):
             table : string giving the table name
             values : dict of new column names and values for them
         """
+        return self._db_pool.runInteraction(
+            self._simple_insert_txn, table, values,
+        )
+
+    def _simple_insert_txn(self, txn, table, values):
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
             ", ".join(k for k in values),
             ", ".join("?" for k in values)
         )
-
-        def func(txn):
-            txn.execute(sql, values.values())
-            return txn.lastrowid
-        return self._db_pool.runInteraction(func)
+        txn.execute(sql, values.values())
+        return txn.lastrowid
 
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False):
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
index 513b72d279..bac3dea955 100644
--- a/synapse/storage/feedback.py
+++ b/synapse/storage/feedback.py
@@ -20,8 +20,8 @@ from ._base import SQLBaseStore
 
 class FeedbackStore(SQLBaseStore):
 
-    def _store_feedback(self, event):
-        return self._simple_insert("feedback", {
+    def _store_feedback_txn(self, txn, event):
+        self._simple_insert_txn(txn, "feedback", {
             "event_id": event.event_id,
             "feedback_type": event.content["type"],
             "room_id": event.room_id,
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 7655f43ede..9fd44f2454 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -114,7 +114,7 @@ class PduStore(SQLBaseStore):
 
         return self._get_pdu_tuples(txn, res)
 
-    def persist_pdu(self, prev_pdus, **cols):
+    def _persist_pdu_txn(self, txn, prev_pdus, cols):
         """Inserts a (non-state) PDU into the database.
 
         Args:
@@ -122,11 +122,6 @@ class PduStore(SQLBaseStore):
             prev_pdus (list)
             **cols: The columns to insert into the PdusTable.
         """
-        return self._db_pool.runInteraction(
-            self._persist_pdu, prev_pdus, cols
-        )
-
-    def _persist_pdu(self, txn, prev_pdus, cols):
         entry = PdusTable.EntryType(
             **{k: cols.get(k, None) for k in PdusTable.fields}
         )
@@ -262,7 +257,7 @@ class PduStore(SQLBaseStore):
 
         return row[0] if row else None
 
-    def update_min_depth_for_context(self, context, depth):
+    def _update_min_depth_for_context_txn(self, txn, context, depth):
         """Update the minimum `depth` of the given context, which is the line
         on which we stop backfilling backwards.
 
@@ -270,11 +265,6 @@ class PduStore(SQLBaseStore):
             context (str)
             depth (int)
         """
-        return self._db_pool.runInteraction(
-            self._update_min_depth_for_context, context, depth
-        )
-
-    def _update_min_depth_for_context(self, txn, context, depth):
         min_depth = self._get_min_depth_interaction(txn, context)
 
         do_insert = depth < min_depth if min_depth else True
@@ -286,7 +276,7 @@ class PduStore(SQLBaseStore):
                 (context, depth)
             )
 
-    def get_latest_pdus_in_context(self, context):
+    def _get_latest_pdus_in_context(self, txn, context):
         """Get's a list of the most current pdus for a given context. This is
         used when we are sending a Pdu and need to fill out the `prev_pdus`
         key
@@ -295,11 +285,6 @@ class PduStore(SQLBaseStore):
             txn
             context
         """
-        return self._db_pool.runInteraction(
-            self._get_latest_pdus_in_context, context
-        )
-
-    def _get_latest_pdus_in_context(self, txn, context):
         query = (
             "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
             "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
@@ -485,7 +470,7 @@ class StatePduStore(SQLBaseStore):
     """A collection of queries for handling state PDUs.
     """
 
-    def persist_state(self, prev_pdus, **cols):
+    def _persist_state_txn(self, txn, prev_pdus, cols):
         """Inserts a state PDU into the database
 
         Args:
@@ -493,12 +478,6 @@ class StatePduStore(SQLBaseStore):
             prev_pdus (list)
             **cols: The columns to insert into the PdusTable and StatePdusTable
         """
-
-        return self._db_pool.runInteraction(
-            self._persist_state, prev_pdus, cols
-        )
-
-    def _persist_state(self, txn, prev_pdus, cols):
         pdu_entry = PdusTable.EntryType(
             **{k: cols.get(k, None) for k in PdusTable.fields}
         )
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index a5751005ef..d1f1a232f8 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -129,8 +129,9 @@ class RoomStore(SQLBaseStore):
 
         defer.returnValue(ret)
 
-    def _store_room_topic(self, event):
-        return self._simple_insert(
+    def _store_room_topic_txn(self, txn, event):
+        self._simple_insert_txn(
+            txn,
             "topics",
             {
                 "event_id": event.event_id,
@@ -139,8 +140,9 @@ class RoomStore(SQLBaseStore):
             }
         )
 
-    def _store_room_name(self, event):
-        return self._simple_insert(
+    def _store_room_name_txn(self, txn, event):
+        self._simple_insert_txn(
+            txn,
             "room_names",
             {
                 "event_id": event.event_id,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4ad37af0f3..2746126e85 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -26,14 +26,14 @@ logger = logging.getLogger(__name__)
 
 class RoomMemberStore(SQLBaseStore):
 
-    @defer.inlineCallbacks
-    def _store_room_member(self, event):
+    def _store_room_member_txn(self, txn, event):
         """Store a room member in the database.
         """
         target_user_id = event.state_key
         domain = self.hs.parse_userid(target_user_id).domain
 
-        yield self._simple_insert(
+        self._simple_insert_txn(
+            txn,
             "room_memberships",
             {
                 "event_id": event.event_id,
@@ -50,13 +50,13 @@ class RoomMemberStore(SQLBaseStore):
                 "INSERT OR IGNORE INTO room_hosts (room_id, host) "
                 "VALUES (?, ?)"
             )
-            yield self._execute(None, sql, event.room_id, domain)
+            txn.execute(sql, (event.room_id, domain))
         else:
             sql = (
                 "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
             )
 
-            yield self._execute(None, sql, event.room_id, domain)
+            txn.execute(sql, (event.room_id, domain))
 
     @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
@@ -75,6 +75,24 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(rows[0] if rows else None)
 
+    def _get_room_member(self, txn, user_id, room_id):
+        sql = (
+            "SELECT e.* FROM events as e"
+            " INNER JOIN room_memberships as m"
+            " ON e.event_id = m.event_id"
+            " INNER JOIN current_state_events as c"
+            " ON m.event_id = c.event_id"
+            " WHERE m.user_id = ? and e.room_id = ?"
+            " LIMIT 1"
+        )
+        txn.execute(sql, (user_id, room_id))
+        rows = self.cursor_to_dict(txn)
+        if rows:
+            return self._parse_event_from_row(rows[0])
+        else:
+            return None
+
+
     def get_room_members(self, room_id, membership=None):
         """Retrieve the current room member list for a room.
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 6a22d5aead..4f42afc015 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -281,17 +281,20 @@ class StreamStore(SQLBaseStore):
             )
         )
 
-    @defer.inlineCallbacks
     def get_room_events_max_id(self):
-        res = yield self._execute_and_decode(
+        return self._db_pool.runInteraction(self._get_room_events_max_id_txn)
+
+    def _get_room_events_max_id_txn(self, txn):
+        txn.execute(
             "SELECT MAX(stream_ordering) as m FROM events"
         )
 
+        res = self.cursor_to_dict(txn)
+
         logger.debug("get_room_events_max_id: %s", res)
 
         if not res or not res[0] or not res[0]["m"]:
-            defer.returnValue("s1")
-            return
+            return "s1"
 
         key = res[0]["m"]
-        defer.returnValue("s%d" % (key,))
+        return "s%d" % (key,)