summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-05-05 17:32:21 +0100
committerMark Haines <mark.haines@matrix.org>2015-05-05 17:32:21 +0100
commitd18f37e026a02b4e899bc96e600850007a613189 (patch)
tree66faf9320e35abcc03d0ff31a767a0de2fefddc8 /synapse/storage
parentSYN-369: Add comments to the sequence number logic in the cache (diff)
downloadsynapse-d18f37e026a02b4e899bc96e600850007a613189.tar.xz
Collect the invalidate callbacks on the transaction object rather than passing around a separate list
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py18
-rw-r--r--synapse/storage/event_federation.py10
-rw-r--r--synapse/storage/events.py48
-rw-r--r--synapse/storage/room.py4
-rw-r--r--synapse/storage/roommember.py8
-rw-r--r--synapse/storage/signatures.py12
-rw-r--r--synapse/storage/state.py2
7 files changed, 51 insertions, 51 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 579ed56377..ccf9697fa3 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -185,12 +185,16 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name", "database_engine"]
+    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
 
-    def __init__(self, txn, name, database_engine):
+    def __init__(self, txn, name, database_engine, after_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
+        object.__setattr__(self, "after_callbacks", after_callbacks)
+
+    def call_after(self, callback, *args):
+        self.after_callbacks.append((callback, args))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -336,6 +340,8 @@ class SQLBaseStore(object):
 
         start_time = time.time() * 1000
 
+        after_callbacks = []
+
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
                 if self.database_engine.is_connection_closed(conn):
@@ -360,10 +366,10 @@ class SQLBaseStore(object):
                     while True:
                         try:
                             txn = conn.cursor()
-                            return func(
-                                LoggingTransaction(txn, name, self.database_engine),
-                                *args, **kwargs
+                            txn = LoggingTransaction(
+                                txn, name, self.database_engine, after_callbacks
                             )
+                            return func(txn, *args, **kwargs)
                         except self.database_engine.module.OperationalError as e:
                             # This can happen if the database disappears mid
                             # transaction.
@@ -412,6 +418,8 @@ class SQLBaseStore(object):
             result = yield self._db_pool.runWithConnection(
                 inner_func, *args, **kwargs
             )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3cd3fbdc9b..893344eff3 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -241,7 +241,7 @@ class EventFederationStore(SQLBaseStore):
 
         return int(min_depth) if min_depth is not None else None
 
-    def _update_min_depth_for_room_txn(self, txn, invalidates, room_id, depth):
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
         min_depth = self._get_min_depth_interaction(txn, room_id)
 
         do_insert = depth < min_depth if min_depth else True
@@ -256,8 +256,8 @@ class EventFederationStore(SQLBaseStore):
                 },
             )
 
-    def _handle_prev_events(self, txn, invalidates, outlier, event_id,
-                            prev_events, room_id):
+    def _handle_prev_events(self, txn, outlier, event_id, prev_events,
+                            room_id):
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -330,9 +330,9 @@ class EventFederationStore(SQLBaseStore):
             )
             txn.execute(query)
 
-            invalidates.append((
+            txn.call_after(
                 self.get_latest_event_ids_in_room.invalidate, room_id
-            ))
+            )
 
     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7dc49ceed6..17f9d27289 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -42,7 +42,7 @@ class EventsStore(SQLBaseStore):
             stream_ordering = self.min_token
 
         try:
-            invalidates = yield self.runInteraction(
+            yield self.runInteraction(
                 "persist_event",
                 self._persist_event_txn,
                 event=event,
@@ -52,11 +52,6 @@ class EventsStore(SQLBaseStore):
                 is_new_state=is_new_state,
                 current_state=current_state,
             )
-            for invalidated in invalidates:
-                invalidated_callback = invalidated[0]
-                invalidated_args = invalidated[1:]
-                invalidated_callback(*invalidated_args)
-
         except _RollbackButIsFineException:
             pass
 
@@ -96,10 +91,9 @@ class EventsStore(SQLBaseStore):
     def _persist_event_txn(self, txn, event, context, backfilled,
                            stream_ordering=None, is_new_state=True,
                            current_state=None):
-        invalidates = []
 
         # Remove the any existing cache entries for the event_id
-        invalidates.append((self._invalidate_get_event_cache, event.event_id))
+        txn.call_after(self._invalidate_get_event_cache, event.event_id)
 
         if stream_ordering is None:
             with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@@ -121,10 +115,12 @@ class EventsStore(SQLBaseStore):
 
             for s in current_state:
                 if s.type == EventTypes.Member:
-                    invalidates.extend([
-                        (self.get_rooms_for_user.invalidate, s.state_key),
-                        (self.get_joined_hosts_for_room.invalidate, s.room_id),
-                    ])
+                    txn.call_after(
+                        self.get_rooms_for_user.invalidate, s.state_key
+                    )
+                    txn.call_after(
+                        self.get_joined_hosts_for_room.invalidate, s.room_id
+                    )
                 self._simple_insert_txn(
                     txn,
                     "current_state_events",
@@ -161,11 +157,10 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()
 
         if not outlier:
-            self._store_state_groups_txn(txn, invalidates, event, context)
+            self._store_state_groups_txn(txn, event, context)
 
             self._update_min_depth_for_room_txn(
                 txn,
-                invalidates,
                 event.room_id,
                 event.depth
             )
@@ -207,11 +202,10 @@ class EventsStore(SQLBaseStore):
                     sql,
                     (False, event.event_id,)
                 )
-            return invalidates
+            return
 
         self._handle_prev_events(
             txn,
-            invalidates,
             outlier=outlier,
             event_id=event.event_id,
             prev_events=event.prev_events,
@@ -219,13 +213,13 @@ class EventsStore(SQLBaseStore):
         )
 
         if event.type == EventTypes.Member:
-            self._store_room_member_txn(txn, invalidates, event)
+            self._store_room_member_txn(txn, event)
         elif event.type == EventTypes.Name:
-            self._store_room_name_txn(txn, invalidates, event)
+            self._store_room_name_txn(txn, event)
         elif event.type == EventTypes.Topic:
-            self._store_room_topic_txn(txn, invalidates, event)
+            self._store_room_topic_txn(txn, event)
         elif event.type == EventTypes.Redaction:
-            self._store_redaction(txn, invalidates, event)
+            self._store_redaction(txn, event)
 
         event_dict = {
             k: v
@@ -295,20 +289,20 @@ class EventsStore(SQLBaseStore):
 
         if context.rejected:
             self._store_rejections_txn(
-                txn, invalidates, event.event_id, context.rejected
+                txn, event.event_id, context.rejected
             )
 
         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
             self._store_event_content_hash_txn(
-                txn, invalidates, event.event_id, hash_alg, hash_bytes,
+                txn, event.event_id, hash_alg, hash_bytes,
             )
 
         for prev_event_id, prev_hashes in event.prev_events:
             for alg, hash_base64 in prev_hashes.items():
                 hash_bytes = decode_base64(hash_base64)
                 self._store_prev_event_hash_txn(
-                    txn, invalidates, event.event_id, prev_event_id, alg,
+                    txn, event.event_id, prev_event_id, alg,
                     hash_bytes
                 )
 
@@ -325,7 +319,7 @@ class EventsStore(SQLBaseStore):
 
         (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
         self._store_event_reference_hash_txn(
-            txn, invalidates, event.event_id, ref_alg, ref_hash_bytes
+            txn, event.event_id, ref_alg, ref_hash_bytes
         )
 
         if event.is_state():
@@ -372,11 +366,11 @@ class EventsStore(SQLBaseStore):
                     }
                 )
 
-        return invalidates
+        return
 
-    def _store_redaction(self, txn, invalidates, event):
+    def _store_redaction(self, txn, event):
         # invalidate the cache for the redacted event
-        invalidates.append((self._invalidate_get_event_cache, event.redacts))
+        txn.call_after(self._invalidate_get_event_cache, event.redacts)
         txn.execute(
             "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
             (event.event_id, event.redacts)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index d42d7ff0e3..f956377632 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -162,7 +162,7 @@ class RoomStore(SQLBaseStore):
 
         defer.returnValue(ret)
 
-    def _store_room_topic_txn(self, txn, invalidates, event):
+    def _store_room_topic_txn(self, txn, event):
         if hasattr(event, "content") and "topic" in event.content:
             self._simple_insert_txn(
                 txn,
@@ -174,7 +174,7 @@ class RoomStore(SQLBaseStore):
                 },
             )
 
-    def _store_room_name_txn(self, txn, invalidates, event):
+    def _store_room_name_txn(self, txn, event):
         if hasattr(event, "content") and "name" in event.content:
             self._simple_insert_txn(
                 txn,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 117da817ba..839c74f63a 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,7 +35,7 @@ RoomsForUser = namedtuple(
 
 class RoomMemberStore(SQLBaseStore):
 
-    def _store_room_member_txn(self, txn, invalidates, event):
+    def _store_room_member_txn(self, txn, event):
         """Store a room member in the database.
         """
         try:
@@ -64,10 +64,8 @@ class RoomMemberStore(SQLBaseStore):
             }
         )
 
-        invalidates.extend([
-            (self.get_rooms_for_user.invalidate, target_user_id),
-            (self.get_joined_hosts_for_room.invalidate, event.room_id),
-        ])
+        txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
+        txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
 
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e3979846e7..f051828630 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -39,8 +39,8 @@ class SignatureStore(SQLBaseStore):
         txn.execute(query, (event_id, ))
         return dict(txn.fetchall())
 
-    def _store_event_content_hash_txn(self, txn, invalidates, event_id,
-                                      algorithm, hash_bytes):
+    def _store_event_content_hash_txn(self, txn, event_id, algorithm,
+                                      hash_bytes):
         """Store a hash for a Event
         Args:
             txn (cursor):
@@ -101,8 +101,8 @@ class SignatureStore(SQLBaseStore):
         txn.execute(query, (event_id, ))
         return {k: v for k, v in txn.fetchall()}
 
-    def _store_event_reference_hash_txn(self, txn, invalidates, event_id,
-                                        algorithm, hash_bytes):
+    def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
+                                        hash_bytes):
         """Store a hash for a PDU
         Args:
             txn (cursor):
@@ -184,8 +184,8 @@ class SignatureStore(SQLBaseStore):
             hashes[algorithm] = hash_bytes
         return results
 
-    def _store_prev_event_hash_txn(self, txn, invalidates, event_id,
-                                   prev_event_id, algorithm, hash_bytes):
+    def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
+                                   algorithm, hash_bytes):
         self._simple_insert_txn(
             txn,
             "event_edge_hashes",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 35d11c27cc..7e55e8bed6 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
             f,
         )
 
-    def _store_state_groups_txn(self, txn, invalidates, event, context):
+    def _store_state_groups_txn(self, txn, event, context):
         if context.current_state is None:
             return