diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7a3f6c4662..c8cab45f77 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -46,6 +46,9 @@ from .tags import TagsStore
from .account_data import AccountDataStore
+from util.id_generators import IdGenerator, StreamIdGenerator
+
+
import logging
@@ -58,6 +61,22 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120*1000
+def get_datastore(hs):
+ logger.info("getting called!")
+
+ conn = hs.get_db_conn()
+ try:
+ cur = conn.cursor()
+ cur.execute("SELECT MIN(stream_ordering) FROM events",)
+ rows = cur.fetchall()
+ min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+ min_token = min(min_token, -1)
+
+ return DataStore(conn, hs, min_token)
+ finally:
+ conn.close()
+
+
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
@@ -79,18 +98,36 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore
):
- def __init__(self, hs):
- super(DataStore, self).__init__(hs)
+ def __init__(self, db_conn, hs, min_stream_token):
self.hs = hs
- self.min_token_deferred = self._get_min_token()
- self.min_token = None
+ self.min_stream_token = min_stream_token
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering"
+ )
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+ self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+ self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
+ self._pushers_id_gen = IdGenerator("pushers", "id", self)
+ self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+ self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+
+ super(DataStore, self).__init__(hs)
+
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 90d7aee94a..5e77320540 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
import logging
from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
-from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
- self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
- self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
- self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
- self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
- self._pushers_id_gen = IdGenerator("pushers", "id", self)
- self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
- self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
- self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -345,7 +333,8 @@ class SQLBaseStore(object):
defer.returnValue(result)
- def cursor_to_dict(self, cursor):
+ @staticmethod
+ def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore:
raise
- @log_function
- def _simple_insert_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
- def _simple_insert_many_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_many_txn(txn, table, values):
if not values:
return
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none,
)
- def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ @classmethod
+ def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
- ret = self._simple_select_onecol_txn(
+ ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
else:
raise StoreError(404, "No row found")
- def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ @staticmethod
+ def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols
)
- def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+ @classmethod
+ def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -627,7 +620,7 @@ class SQLBaseStore(object):
)
txn.execute(sql)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols,
@@ -662,7 +655,8 @@ class SQLBaseStore(object):
defer.returnValue(results)
- def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+ @classmethod
+ def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -699,7 +693,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, values)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
@@ -726,7 +720,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ @staticmethod
+ def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -743,7 +738,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
- def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ @staticmethod
+ def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
@@ -784,7 +780,8 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
- def _simple_delete_txn(self, txn, table, keyvalues):
+ @staticmethod
+ def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba368a3eca..298cb9bada 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- start = self.min_token - 1
- self.min_token -= len(events_and_contexts) + 1
- stream_orderings = range(start, self.min_token, -1)
+ start = self.min_stream_token - 1
+ self.min_stream_token -= len(events_and_contexts) + 1
+ stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
@@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
+ self.min_stream_token -= 1
+ stream_ordering = self.min_stream_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c4232bdc65..c0593e23ee 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -31,7 +31,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
- self._receipts_stream_cache = _RoomStreamChangeCache()
+ self._receipts_stream_cache = _RoomStreamChangeCache(
+ self._receipts_id_gen.get_max_token(None)
+ )
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -377,11 +379,11 @@ class _RoomStreamChangeCache(object):
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
- def __init__(self, size_of_cache=10000):
+ def __init__(self, current_key, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
- self._earliest_key = None
+ self._earliest_key = current_key
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..e31bad258a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -444,19 +444,6 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0] if rows else 0
- @defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
-
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
-
- logger.debug("min_token is: %s", self.min_token)
-
- defer.returnValue(self.min_token)
-
@staticmethod
def _set_before_and_after(events, rows):
for event, row in zip(events, rows):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index ed9c91e5ea..4c39e07cbd 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
@@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
- def __init__(self, hs):
- super(TagsStore, self).__init__(hs)
-
- self._account_data_id_gen = StreamIdGenerator(
- "account_data_max_stream_id", "stream_id"
- )
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f58bf7fd2c..5c522f4ab9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,28 +72,24 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
- def __init__(self, table, column):
+ def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
- self._current_max = None
+ cur = db_conn.cursor()
+ self._current_max = self._get_or_compute_current_max(cur)
+ cur.close()
+
self._unfinished_ids = deque()
- @defer.inlineCallbacks
def get_next(self, store):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
self._current_max += 1
next_id = self._current_max
@@ -108,21 +104,14 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
@@ -139,24 +128,17 @@ class StreamIdGenerator(object):
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
if self._unfinished_ids:
- defer.returnValue(self._unfinished_ids[0] - 1)
+ return self._unfinished_ids[0] - 1
- defer.returnValue(self._current_max)
+ return self._current_max
def _get_or_compute_current_max(self, txn):
with self._lock:
|