diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 48 |
1 files changed, 41 insertions, 7 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b75b79df36..32c6677d47 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -152,8 +152,8 @@ class SQLBaseStore(object): def __init__(self, hs): self.hs = hs - self._db_pool = hs.get_db_pool() self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() self._previous_txn_total_time = 0 self._current_txn_total_time = 0 @@ -453,7 +453,9 @@ class SQLBaseStore(object): keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values insertion_values (dict): key/values to use when inserting - Returns: A deferred + Returns: + Deferred(bool): True if a new entry was created, False if an + existing one was updated. """ return self.runInteraction( desc, @@ -498,6 +500,10 @@ class SQLBaseStore(object): ) txn.execute(sql, allvalues.values()) + return True + else: + return False + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to @@ -810,11 +816,39 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def get_next_stream_id(self): - with self._next_stream_id_lock: - i = self._next_stream_id - self._next_stream_id += 1 - return i + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, + max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - 100000" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + txn.close() + + cache = { + row[0]: int(row[1]) + for row in rows + } + + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val class _RollbackButIsFineException(Exception): |