Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py    2
-rw-r--r--  synapse/storage/events.py     77
-rw-r--r--  synapse/storage/roommember.py 27
-rw-r--r--  synapse/storage/state.py      28
4 files changed, 72 insertions, 62 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4f4c723c5b..6c32773f25 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -115,7 +115,7 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
+        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
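
The switch above (StreamIdGenerator to IdGenerator for state_groups) matters because the two generators are consumed differently in the rest of this diff: StreamIdGenerator's get_next()/get_next_mult() are used as context managers (the `with ... as stream_ordering:` blocks in events.py), while IdGenerator's get_next() is returned directly (see get_next_state_group() at the bottom of state.py). A minimal sketch of that difference, assuming only what the usage in this diff implies; it is not Synapse's actual implementation:

    # Sketch only: stand-ins for the id generators, shaped after how they are
    # *used* in this diff, not after the real synapse.storage.util classes.
    import itertools
    from contextlib import contextmanager

    class IdGenerator(object):
        """get_next() hands out the next integer id directly."""
        def __init__(self, start=1):
            self._counter = itertools.count(start)

        def get_next(self):
            return next(self._counter)

    class StreamIdGenerator(object):
        """get_next()/get_next_mult() are context managers, so the ids stay
        tied to the write that uses them (the `with` blocks in events.py)."""
        def __init__(self, start=1):
            self._counter = itertools.count(start)

        @contextmanager
        def get_next(self):
            yield next(self._counter)

        @contextmanager
        def get_next_mult(self, n):
            yield [next(self._counter) for _ in range(n)]

    # Usage mirroring the two patterns visible in this commit:
    state_groups = IdGenerator()
    next_group = state_groups.get_next()          # get_next_state_group() style

    stream = StreamIdGenerator()
    with stream.get_next() as stream_ordering:    # persist_event() style
        pass
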
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5cbe8c5978..1a7d4c5199 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore):
                 len(events_and_contexts)
             )
 
-        state_group_id_manager = self._state_groups_id_gen.get_next_mult(
-            len(events_and_contexts)
-        )
         with stream_ordering_manager as stream_orderings:
-            with state_group_id_manager as state_group_ids:
-                for (event, context), stream, state_group_id in zip(
-                    events_and_contexts, stream_orderings, state_group_ids
-                ):
-                    event.internal_metadata.stream_ordering = stream
-                    # Assign a state group_id in case a new id is needed for
-                    # this context. In theory we only need to assign this
-                    # for contexts that have current_state and aren't outliers
-                    # but that make the code more complicated. Assigning an ID
-                    # per event only causes the state_group_ids to grow as fast
-                    # as the stream_ordering so in practise shouldn't be a problem.
-                    context.new_state_group_id = state_group_id
-
-                chunks = [
-                    events_and_contexts[x:x + 100]
-                    for x in xrange(0, len(events_and_contexts), 100)
-                ]
+            for (event, context), stream, in zip(
+                events_and_contexts, stream_orderings
+            ):
+                event.internal_metadata.stream_ordering = stream
 
-                for chunk in chunks:
-                    # We can't easily parallelize these since different chunks
-                    # might contain the same event. :(
-                    yield self.runInteraction(
-                        "persist_events",
-                        self._persist_events_txn,
-                        events_and_contexts=chunk,
-                        backfilled=backfilled,
-                        delete_existing=delete_existing,
-                    )
-                    persist_event_counter.inc_by(len(chunk))
+            chunks = [
+                events_and_contexts[x:x + 100]
+                for x in xrange(0, len(events_and_contexts), 100)
+            ]
+
+            for chunk in chunks:
+                # We can't easily parallelize these since different chunks
+                # might contain the same event. :(
+                yield self.runInteraction(
+                    "persist_events",
+                    self._persist_events_txn,
+                    events_and_contexts=chunk,
+                    backfilled=backfilled,
+                    delete_existing=delete_existing,
+                )
+                persist_event_counter.inc_by(len(chunk))
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
@@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore):
                        delete_existing=False):
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
-                with self._state_groups_id_gen.get_next() as state_group_id:
-                    event.internal_metadata.stream_ordering = stream_ordering
-                    context.new_state_group_id = state_group_id
-                    yield self.runInteraction(
-                        "persist_event",
-                        self._persist_event_txn,
-                        event=event,
-                        context=context,
-                        current_state=current_state,
-                        backfilled=backfilled,
-                        delete_existing=delete_existing,
-                    )
-                    persist_event_counter.inc()
+                event.internal_metadata.stream_ordering = stream_ordering
+                yield self.runInteraction(
+                    "persist_event",
+                    self._persist_event_txn,
+                    event=event,
+                    context=context,
+                    current_state=current_state,
+                    backfilled=backfilled,
+                    delete_existing=delete_existing,
+                )
+                persist_event_counter.inc()
         except _RollbackButIsFineException:
             pass
 
@@ -528,7 +515,7 @@ class EventsStore(SQLBaseStore):
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
                 stream_order = event.internal_metadata.stream_ordering
-                state_group_id = context.state_group or context.new_state_group_id
+                state_group_id = context.state_group
                 self._simple_insert_txn(
                     txn,
                     table="ex_outlier_stream",
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index cab1660830..6ab10db328 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
             desc="who_forgot"
         )
 
-    def get_joined_users_from_context(self, room_id, state_group, state_ids):
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, state_ids
+            event.room_id, state_group, context.current_state_ids, event=event,
+        )
+
+    def get_joined_users_from_state(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            room_id, state_group, state_ids,
         )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context):
+                                       cache_context, event=None):
        # We don't use `state_group`, it's there so that we can cache based
        # on it. However, it's important that it's never None, since two
        # current_states with a state_group of None are likely to be different.
@@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
             desc="_get_joined_users_from_context",
         )
 
-        defer.returnValue(set(row["user_id"] for row in rows))
+        users_in_room = set(row["user_id"] for row in rows)
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room.add(event.state_key)
+
+        defer.returnValue(users_in_room)
 
     def is_host_joined(self, room_id, host, state_group, state_ids):
         if not state_group:
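
Both new entry points above keep the `state_group = object()` trick from the old code: the cache on _get_joined_users_from_context is keyed on (room_id, state_group), so a state_group of None has to be replaced with a fresh sentinel or two unrelated "not yet assigned" states would share a cache entry. A dict-based sketch of why that works (a stand-in for the @cachedInlineCallbacks cache, not the real descriptor):

    # Toy cache keyed on (room_id, state_group), mirroring num_args=2 on
    # _get_joined_users_from_context. compute() stands in for the DB query.
    cache = {}

    def get_joined_users(room_id, state_group, compute):
        key = (room_id, state_group)
        if key not in cache:
            cache[key] = compute()
        return cache[key]

    # With None as the key, two different "not yet assigned" states collide:
    get_joined_users("!room:hs", None, lambda: {"@alice:hs"})
    assert get_joined_users("!room:hs", None, lambda: {"@bob:hs"}) == {"@alice:hs"}  # stale hit

    # With a fresh sentinel per call they cannot, since object() != object():
    get_joined_users("!room:hs", object(), lambda: {"@alice:hs"})
    assert get_joined_users("!room:hs", object(), lambda: {"@bob:hs"}) == {"@bob:hs"}
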
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b1d461fef5..ec551b0b4f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
             for group, event_id_map in group_to_ids.items()
         })
 
+    def _have_persisted_state_group_txn(self, txn, state_group):
+        txn.execute(
+            "SELECT count(*) FROM state_groups WHERE id = ?",
+            (state_group,)
+        )
+        row = txn.fetchone()
+        return row and row[0]
+
     def _store_mult_state_groups_txn(self, txn, events_and_contexts):
         state_groups = {}
         for event, context in events_and_contexts:
@@ -92,22 +100,19 @@ class StateStore(SQLBaseStore):
             if context.current_state_ids is None:
                 continue
 
-            if context.state_group is not None:
-                state_groups[event.event_id] = context.state_group
+            state_groups[event.event_id] = context.state_group
+
+            if self._have_persisted_state_group_txn(txn, context.state_group):
+                logger.info("Already persisted state_group: %r", context.state_group)
                 continue
 
             state_event_ids = dict(context.current_state_ids)
 
-            if event.is_state():
-                state_event_ids[(event.type, event.state_key)] = event.event_id
-
-            state_group = context.new_state_group_id
-
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": state_group,
+                    "id": context.state_group,
                     "room_id": event.room_id,
                     "event_id": event.event_id,
                 },
@@ -118,7 +123,7 @@ class StateStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": state_group,
+                        "state_group": context.state_group,
                         "room_id": event.room_id,
                         "type": key[0],
                         "state_key": key[1],
@@ -127,7 +132,6 @@ class StateStore(SQLBaseStore):
                     for key, state_id in state_event_ids.items()
                 ],
             )
-            state_groups[event.event_id] = state_group
 
         self._simple_insert_many_txn(
             txn,
@@ -527,5 +531,5 @@ class StateStore(SQLBaseStore):
             "get_all_new_state_groups", get_all_new_state_groups_txn
         )
 
-    def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_current_token()
+    def get_next_state_group(self):
+        return self._state_groups_id_gen.get_next()
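
The new _have_persisted_state_group_txn() helper is what lets _store_mult_state_groups_txn() skip groups that were already written, now that several events in a batch can legitimately share a state group. A standalone sqlite3 sketch of the same existence check, with a simplified schema (not Synapse's real migrations):

    # Simplified schema; the real state_groups table is created by Synapse's
    # schema deltas, not like this.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE state_groups (id BIGINT PRIMARY KEY, room_id TEXT, event_id TEXT)"
    )

    def have_persisted_state_group(txn, state_group):
        txn.execute(
            "SELECT count(*) FROM state_groups WHERE id = ?",
            (state_group,),
        )
        row = txn.fetchone()
        return bool(row and row[0])

    txn = conn.cursor()
    assert not have_persisted_state_group(txn, 42)
    txn.execute(
        "INSERT INTO state_groups (id, room_id, event_id) VALUES (?, ?, ?)",
        (42, "!room:hs", "$event:hs"),
    )
    assert have_persisted_state_group(txn, 42)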