summary refs log tree commit diff
path: root/synapse/storage/event_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/event_federation.py')
-rw-r--r--synapse/storage/event_federation.py87
1 files changed, 51 insertions, 36 deletions
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 032334bfd6..74b4e23590 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 from syutil.base64util import encode_base64
 
 import logging
@@ -96,11 +96,23 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
+    @cached()
+    def get_latest_event_ids_in_room(self, room_id):
+        return self._simple_select_onecol(
+            table="event_forward_extremities",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="event_id",
+            desc="get_latest_event_ids_in_room",
+        )
+
     def _get_latest_events_in_room(self, txn, room_id):
         sql = (
             "SELECT e.event_id, e.depth FROM events as e "
             "INNER JOIN event_forward_extremities as f "
             "ON e.event_id = f.event_id "
+            "AND e.room_id = f.room_id "
             "WHERE f.room_id = ?"
         )
 
@@ -153,7 +165,7 @@ class EventFederationStore(SQLBaseStore):
         results = self._get_prev_events_and_state(
             txn,
             event_id,
-            is_state=1,
+            is_state=True,
         )
 
         return [(e_id, h, ) for e_id, h, _ in results]
@@ -164,7 +176,7 @@ class EventFederationStore(SQLBaseStore):
         }
 
         if is_state is not None:
-            keyvalues["is_state"] = is_state
+            keyvalues["is_state"] = bool(is_state)
 
         res = self._simple_select_list_txn(
             txn,
@@ -242,7 +254,6 @@ class EventFederationStore(SQLBaseStore):
                     "room_id": room_id,
                     "min_depth": depth,
                 },
-                or_replace=True,
             )
 
     def _handle_prev_events(self, txn, outlier, event_id, prev_events,
@@ -251,19 +262,19 @@ class EventFederationStore(SQLBaseStore):
         For the given event, update the event edges table and forward and
         backward extremities tables.
         """
-        for e_id, _ in prev_events:
-            # TODO (erikj): This could be done as a bulk insert
-            self._simple_insert_txn(
-                txn,
-                table="event_edges",
-                values={
+        self._simple_insert_many_txn(
+            txn,
+            table="event_edges",
+            values=[
+                {
                     "event_id": event_id,
                     "prev_event_id": e_id,
                     "room_id": room_id,
-                    "is_state": 0,
-                },
-                or_ignore=True,
-            )
+                    "is_state": False,
+                }
+                for e_id, _ in prev_events
+            ],
+        )
 
         # Update the extremities table if this is not an outlier.
         if not outlier:
@@ -281,33 +292,33 @@ class EventFederationStore(SQLBaseStore):
             # We only insert as a forward extremity the new event if there are
             # no other events that reference it as a prev event
             query = (
-                "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
-                "SELECT ?, ? WHERE NOT EXISTS ("
-                "SELECT 1 FROM %(event_edges)s WHERE "
-                "prev_event_id = ? "
-                ")"
-            ) % {
-                "table": "event_forward_extremities",
-                "event_edges": "event_edges",
-            }
+                "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
+            )
 
-            logger.debug("query: %s", query)
+            txn.execute(query, (event_id,))
+
+            if not txn.fetchone():
+                query = (
+                    "INSERT INTO event_forward_extremities"
+                    " (event_id, room_id)"
+                    " VALUES (?, ?)"
+                )
 
-            txn.execute(query, (event_id, room_id, event_id))
+                txn.execute(query, (event_id, room_id))
 
             # Insert all the prev_events as a backwards thing, they'll get
             # deleted in a second if they're incorrect anyway.
-            for e_id, _ in prev_events:
-                # TODO (erikj): This could be done as a bulk insert
-                self._simple_insert_txn(
-                    txn,
-                    table="event_backward_extremities",
-                    values={
+            self._simple_insert_many_txn(
+                txn,
+                table="event_backward_extremities",
+                values=[
+                    {
                         "event_id": e_id,
                         "room_id": room_id,
-                    },
-                    or_ignore=True,
-                )
+                    }
+                    for e_id, _ in prev_events
+                ],
+            )
 
             # Also delete from the backwards extremities table all ones that
             # reference events that we have already seen
@@ -321,6 +332,10 @@ class EventFederationStore(SQLBaseStore):
             )
             txn.execute(query)
 
+            txn.call_after(
+                self.get_latest_event_ids_in_room.invalidate, room_id
+            )
+
     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
@@ -400,7 +415,7 @@ class EventFederationStore(SQLBaseStore):
 
         query = (
             "SELECT prev_event_id FROM event_edges "
-            "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+            "WHERE room_id = ? AND event_id = ? AND is_state = ? "
             "LIMIT ?"
         )
 
@@ -409,7 +424,7 @@ class EventFederationStore(SQLBaseStore):
             for event_id in front:
                 txn.execute(
                     query,
-                    (room_id, event_id, limit - len(event_results))
+                    (room_id, event_id, False, limit - len(event_results))
                 )
 
                 for e_id, in txn.fetchall():