diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 90 |
1 files changed, 61 insertions, 29 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 22d6257a9f..be61147b9b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import sys +import threading +import time -from synapse.api.errors import StoreError -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches.descriptors import Cache -from synapse.storage.engines import PostgresEngine +from six import PY2, iteritems, iterkeys, itervalues +from six.moves import intern, range +from canonicaljson import json from prometheus_client import Histogram from twisted.internet import defer -import sys -import time -import threading - -from six import itervalues, iterkeys, iteritems -from six.moves import intern, range +from synapse.api.errors import StoreError +from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import Cache +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) @@ -221,7 +221,7 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, - logging_context, func, *args, **kwargs): + func, *args, **kwargs): start = time.time() txn_id = self._TXN_ID @@ -285,8 +285,7 @@ class SQLBaseStore(object): end = time.time() duration = end - start - if logging_context is not None: - logging_context.add_database_transaction(duration) + LoggingContext.current_context().add_database_transaction(duration) transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) @@ -310,19 +309,21 @@ class SQLBaseStore(object): Returns: Deferred: The result of func """ - current_context = LoggingContext.current_context() - after_callbacks = [] exception_callbacks = [] - def inner_func(conn, *args, **kwargs): - return self._new_transaction( - conn, desc, after_callbacks, exception_callbacks, current_context, - func, *args, **kwargs + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warn( + "Starting db txn '%s' from sentinel context", + desc, ) try: - result = yield self.runWithConnection(inner_func, *args, **kwargs) + result = yield self.runWithConnection( + self._new_transaction, + desc, after_callbacks, exception_callbacks, func, + *args, **kwargs + ) for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) @@ -347,22 +348,25 @@ class SQLBaseStore(object): Returns: Deferred: The result of func """ - current_context = LoggingContext.current_context() + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warn( + "Starting db connection from sentinel context: metrics will be lost", + ) + parent_context = None start_time = time.time() def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection") as context: + with LoggingContext("runWithConnection", parent_context) as context: sched_duration_sec = time.time() - start_time sql_scheduling_timer.observe(sched_duration_sec) - current_context.add_database_scheduled(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() - current_context.copy_to(context) - return func(conn, *args, **kwargs) with PreserveLoggingContext(): @@ -1147,17 +1151,16 @@ class SQLBaseStore(object): defer.returnValue(retval) def get_user_count_txn(self, txn): - """Get a total number of registerd users in the users list. + """Get a total number of registered users in the users list. Args: txn : Transaction object Returns: - defer.Deferred: resolves to int + int : number of users """ sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" txn.execute(sql_count) - count = txn.fetchone()[0] - defer.returnValue(count) + return txn.fetchone()[0] def _simple_search_list(self, table, term, col, retcols, desc="_simple_search_list"): @@ -1214,3 +1217,32 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass + + +def db_to_json(db_content): + """ + Take some data from a database row and return a JSON-decoded object. + + Args: + db_content (memoryview|buffer|bytes|bytearray|unicode) + """ + # psycopg2 on Python 3 returns memoryview objects, which we need to + # cast to bytes to decode + if isinstance(db_content, memoryview): + db_content = db_content.tobytes() + + # psycopg2 on Python 2 returns buffer objects, which we need to cast to + # bytes to decode + if PY2 and isinstance(db_content, buffer): + db_content = bytes(db_content) + + # Decode it to a Unicode string before feeding it to json.loads, so we + # consistenty get a Unicode-containing object out. + if isinstance(db_content, (bytes, bytearray)): + db_content = db_content.decode('utf8') + + try: + return json.loads(db_content) + except Exception: + logging.warning("Tried to decode '%r' as JSON and failed", db_content) + raise |