diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 58b73af7d2..6f54036d67 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,6 +16,7 @@ import logging
from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine
@@ -27,10 +28,6 @@ from twisted.internet import defer
import sys
import time
import threading
-import os
-
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
logger = logging.getLogger(__name__)
@@ -52,13 +49,17 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
+ __slots__ = [
+ "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ ]
- def __init__(self, txn, name, database_engine, after_callbacks):
+ def __init__(self, txn, name, database_engine, after_callbacks,
+ final_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
+ object.__setattr__(self, "final_callbacks", final_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -67,6 +68,9 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
+ def call_finally(self, callback, *args, **kwargs):
+ self.final_callbacks.append((callback, args, kwargs))
+
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -217,8 +221,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, logging_context,
- func, *args, **kwargs):
+ def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+ logging_context, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -237,7 +241,8 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks
+ txn, name, self.database_engine, after_callbacks,
+ final_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -298,6 +303,7 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
after_callbacks = []
+ final_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
@@ -309,7 +315,7 @@ class SQLBaseStore(object):
current_context.copy_to(context)
return self._new_transaction(
- conn, desc, after_callbacks, current_context,
+ conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
)
@@ -318,9 +324,13 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
- finally:
+
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
+ finally:
+ for after_callback, after_args, after_kwargs in final_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
defer.returnValue(result)
@defer.inlineCallbacks
@@ -425,6 +435,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
+ def _simple_insert_many(self, table, values, desc):
+ return self.runInteraction(
+ desc, self._simple_insert_many_txn, table, values
+ )
+
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
@@ -936,7 +951,7 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_after(ctx.__exit__, None, None, None)
+ txn.call_finally(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
|