diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 22d6257a9f..be61147b9b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import sys
+import threading
+import time
-from synapse.api.errors import StoreError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches.descriptors import Cache
-from synapse.storage.engines import PostgresEngine
+from six import PY2, iteritems, iterkeys, itervalues
+from six.moves import intern, range
+from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
-import sys
-import time
-import threading
-
-from six import itervalues, iterkeys, iteritems
-from six.moves import intern, range
+from synapse.api.errors import StoreError
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import Cache
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -221,7 +221,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
- logging_context, func, *args, **kwargs):
+ func, *args, **kwargs):
start = time.time()
txn_id = self._TXN_ID
@@ -285,8 +285,7 @@ class SQLBaseStore(object):
end = time.time()
duration = end - start
- if logging_context is not None:
- logging_context.add_database_transaction(duration)
+ LoggingContext.current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@@ -310,19 +309,21 @@ class SQLBaseStore(object):
Returns:
Deferred: The result of func
"""
- current_context = LoggingContext.current_context()
-
after_callbacks = []
exception_callbacks = []
- def inner_func(conn, *args, **kwargs):
- return self._new_transaction(
- conn, desc, after_callbacks, exception_callbacks, current_context,
- func, *args, **kwargs
+ if LoggingContext.current_context() == LoggingContext.sentinel:
+ logger.warn(
+ "Starting db txn '%s' from sentinel context",
+ desc,
)
try:
- result = yield self.runWithConnection(inner_func, *args, **kwargs)
+ result = yield self.runWithConnection(
+ self._new_transaction,
+ desc, after_callbacks, exception_callbacks, func,
+ *args, **kwargs
+ )
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
@@ -347,22 +348,25 @@ class SQLBaseStore(object):
Returns:
Deferred: The result of func
"""
- current_context = LoggingContext.current_context()
+ parent_context = LoggingContext.current_context()
+ if parent_context == LoggingContext.sentinel:
+ logger.warn(
+ "Starting db connection from sentinel context: metrics will be lost",
+ )
+ parent_context = None
start_time = time.time()
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection") as context:
+ with LoggingContext("runWithConnection", parent_context) as context:
sched_duration_sec = time.time() - start_time
sql_scheduling_timer.observe(sched_duration_sec)
- current_context.add_database_scheduled(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- current_context.copy_to(context)
-
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
@@ -1147,17 +1151,16 @@ class SQLBaseStore(object):
defer.returnValue(retval)
def get_user_count_txn(self, txn):
- """Get a total number of registerd users in the users list.
+ """Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
- defer.Deferred: resolves to int
+ int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
- count = txn.fetchone()[0]
- defer.returnValue(count)
+ return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):
@@ -1214,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
+
+
+def db_to_json(db_content):
+ """
+ Take some data from a database row and return a JSON-decoded object.
+
+ Args:
+ db_content (memoryview|buffer|bytes|bytearray|unicode)
+ """
+ # psycopg2 on Python 3 returns memoryview objects, which we need to
+ # cast to bytes to decode
+ if isinstance(db_content, memoryview):
+ db_content = db_content.tobytes()
+
+ # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+ # bytes to decode
+ if PY2 and isinstance(db_content, buffer):
+ db_content = bytes(db_content)
+
+ # Decode it to a Unicode string before feeding it to json.loads, so we
+ # consistenty get a Unicode-containing object out.
+ if isinstance(db_content, (bytes, bytearray)):
+ db_content = db_content.decode('utf8')
+
+ try:
+ return json.loads(db_content)
+ except Exception:
+ logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+ raise
|