summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py56
1 files changed, 45 insertions, 11 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 456e4bd45d..dbc0e49c1f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,6 +15,10 @@
 
 from ._base import SQLBaseStore
 
+from twisted.internet import defer
+
+from synapse.util.stringutils import random_string
+
 import logging
 
 logger = logging.getLogger(__name__)
@@ -89,29 +93,31 @@ class StateStore(SQLBaseStore):
 
         state_group = context.state_group
         if not state_group:
-            state_group = self._simple_insert_txn(
+            state_group = self._state_groups_id_gen.get_next_txn(txn)
+            self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
+                    "id": state_group,
                     "room_id": event.room_id,
                     "event_id": event.event_id,
                 },
-                or_ignore=True,
             )
 
-            for state in state_events.values():
-                self._simple_insert_txn(
-                    txn,
-                    table="state_groups_state",
-                    values={
+            self._simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                values=[
+                    {
                         "state_group": state_group,
                         "room_id": state.room_id,
                         "type": state.type,
                         "state_key": state.state_key,
                         "event_id": state.event_id,
-                    },
-                    or_ignore=True,
-                )
+                    }
+                    for state in state_events.values()
+                ],
+            )
 
         self._simple_insert_txn(
             txn,
@@ -120,5 +126,33 @@ class StateStore(SQLBaseStore):
                 "state_group": state_group,
                 "event_id": event.event_id,
             },
-            or_replace=True,
         )
+
+    @defer.inlineCallbacks
+    def get_current_state(self, room_id, event_type=None, state_key=""):
+        def f(txn):
+            sql = (
+                "SELECT event_id FROM current_state_events"
+                " WHERE room_id = ? "
+            )
+
+            if event_type and state_key is not None:
+                sql += " AND type = ? AND state_key = ? "
+                args = (room_id, event_type, state_key)
+            elif event_type:
+                sql += " AND type = ?"
+                args = (room_id, event_type)
+            else:
+                args = (room_id, )
+
+            txn.execute(sql, args)
+            results = self.cursor_to_dict(txn)
+
+            return self._parse_events_txn(txn, results)
+
+        events = yield self.runInteraction("get_current_state", f)
+        defer.returnValue(events)
+
+
+def _make_group_id(clock):
+    return str(int(clock.time_msec())) + random_string(5)