summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/__init__.py14
-rw-r--r--synapse/storage/account_data.py8
-rw-r--r--synapse/storage/events.py6
-rw-r--r--synapse/storage/presence.py4
-rw-r--r--synapse/storage/push_rule.py4
-rw-r--r--synapse/storage/pusher.py2
-rw-r--r--synapse/storage/receipts.py4
-rw-r--r--synapse/storage/registration.py6
-rw-r--r--synapse/storage/state.py2
-rw-r--r--synapse/storage/tags.py8
-rw-r--r--synapse/storage/transactions.py2
-rw-r--r--synapse/storage/util/id_generators.py69
12 files changed, 52 insertions, 77 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 9be1d12fac..f257721ea3 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "presence_stream", "stream_id"
         )
 
-        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
-        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
-        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
-        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
-        self._pushers_id_gen = IdGenerator("pushers", "id", self)
-        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
-        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+        self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
+        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+        self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
+        self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
 
         events_max = self._stream_id_gen.get_max_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 05f98a9a29..faddefe219 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -195,12 +195,12 @@ class AccountDataStore(SQLBaseStore):
             )
             self._update_max_stream_id(txn, next_id)
 
-        with (yield self._account_data_id_gen.get_next(self)) as next_id:
+        with self._account_data_id_gen.get_next() as next_id:
             yield self.runInteraction(
                 "add_room_account_data", add_account_data_txn, next_id
             )
 
-        result = yield self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_max_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -234,12 +234,12 @@ class AccountDataStore(SQLBaseStore):
             )
             self._update_max_stream_id(txn, next_id)
 
-        with (yield self._account_data_id_gen.get_next(self)) as next_id:
+        with self._account_data_id_gen.get_next() as next_id:
             yield self.runInteraction(
                 "add_user_account_data", add_account_data_txn, next_id
             )
 
-        result = yield self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_max_token()
         defer.returnValue(result)
 
     def _update_max_stream_id(self, txn, next_id):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c0872dd7e2..60936500d8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
                 yield stream_orderings
             stream_ordering_manager = stream_ordering_manager()
         else:
-            stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
-                self, len(events_and_contexts)
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+                len(events_and_contexts)
             )
 
         with stream_ordering_manager as stream_orderings:
@@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore):
             stream_ordering = self.min_stream_token
 
         if stream_ordering is None:
-            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+            stream_ordering_manager = self._stream_id_gen.get_next()
         else:
             @contextmanager
             def stream_ordering_manager():
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index de15741893..4cec31e316 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState",
 class PresenceStore(SQLBaseStore):
     @defer.inlineCallbacks
     def update_presence(self, presence_states):
-        stream_ordering_manager = yield self._presence_id_gen.get_next_mult(
-            self, len(presence_states)
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+            len(presence_states)
         )
 
         with stream_ordering_manager as stream_orderings:
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index bb5c14d912..56e69495b1 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore):
 
         if txn.rowcount == 0:
             # We didn't update a row with the given rule_id so insert one
-            push_rule_id = self._push_rule_id_gen.get_next_txn(txn)
+            push_rule_id = self._push_rule_id_gen.get_next()
 
             self._simple_insert_txn(
                 txn,
@@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore):
         defer.returnValue(ret)
 
     def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
-        new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
+        new_id = self._push_rules_enable_id_gen.get_next()
         self._simple_upsert_txn(
             txn,
             "push_rules_enable",
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index c23648cdbc..7693ab9082 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore):
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data, profile_tag=""):
         try:
-            next_id = yield self._pushers_id_gen.get_next()
+            next_id = self._pushers_id_gen.get_next()
             yield self._simple_upsert(
                 "pushers",
                 dict(
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 6567fa844f..dbc074d6b5 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        stream_id_manager = yield self._receipts_id_gen.get_next(self)
+        stream_id_manager = self._receipts_id_gen.get_next()
         with stream_id_manager as stream_id:
             have_persisted = yield self.runInteraction(
                 "insert_linearized_receipt",
@@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
             room_id, receipt_type, user_id, event_ids, data
         )
 
-        max_persisted_id = yield self._stream_id_gen.get_max_token()
+        max_persisted_id = self._stream_id_gen.get_max_token()
 
         defer.returnValue((stream_id, max_persisted_id))
 
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 03a9b66e4a..ad1157f979 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
         Raises:
             StoreError if there was a problem adding this.
         """
-        next_id = yield self._access_tokens_id_gen.get_next()
+        next_id = self._access_tokens_id_gen.get_next()
 
         yield self._simple_insert(
             "access_tokens",
@@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
         Raises:
             StoreError if there was a problem adding this.
         """
-        next_id = yield self._refresh_tokens_id_gen.get_next()
+        next_id = self._refresh_tokens_id_gen.get_next()
 
         yield self._simple_insert(
             "refresh_tokens",
@@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
     def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
         now = int(self.clock.time())
 
-        next_id = self._access_tokens_id_gen.get_next_txn(txn)
+        next_id = self._access_tokens_id_gen.get_next()
 
         try:
             if was_guest:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 372b540002..8ed8a21b0a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
             if event.is_state():
                 state_events[(event.type, event.state_key)] = event
 
-            state_group = self._state_groups_id_gen.get_next_txn(txn)
+            state_group = self._state_groups_id_gen.get_next()
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index b225f508a5..a0e6b42b30 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -195,12 +195,12 @@ class TagsStore(SQLBaseStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with (yield self._account_data_id_gen.get_next(self)) as next_id:
+        with self._account_data_id_gen.get_next() as next_id:
             yield self.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = yield self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_max_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -217,12 +217,12 @@ class TagsStore(SQLBaseStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with (yield self._account_data_id_gen.get_next(self)) as next_id:
+        with self._account_data_id_gen.get_next() as next_id:
             yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = yield self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_max_token()
         defer.returnValue(result)
 
     def _update_revision_txn(self, txn, user_id, room_id, next_id):
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 4475c451c1..d338dfcf0a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
     def _prep_send_transaction(self, txn, transaction_id, destination,
                                origin_server_ts):
 
-        next_id = self._transaction_id_gen.get_next_txn(txn)
+        next_id = self._transaction_id_gen.get_next()
 
         # First we find out what the prev_txns should be.
         # Since we know that we are only sending one transaction at a time,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ef5e4a4668..efe3f68e6e 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,51 +13,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from collections import deque
 import contextlib
 import threading
 
 
 class IdGenerator(object):
-    def __init__(self, table, column, store):
+    def __init__(self, db_conn, table, column):
         self.table = table
         self.column = column
-        self.store = store
         self._lock = threading.Lock()
-        self._next_id = None
+        cur = db_conn.cursor()
+        self._next_id = self._load_next_id(cur)
+        cur.close()
 
-    @defer.inlineCallbacks
-    def get_next(self):
-        if self._next_id is None:
-            yield self.store.runInteraction(
-                "IdGenerator_%s" % (self.table,),
-                self.get_next_txn,
-            )
+    def _load_next_id(self, txn):
+        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
+        val, = txn.fetchone()
+        return val + 1 if val else 1
 
+    def get_next(self):
         with self._lock:
             i = self._next_id
             self._next_id += 1
-            defer.returnValue(i)
-
-    def get_next_txn(self, txn):
-        with self._lock:
-            if self._next_id:
-                i = self._next_id
-                self._next_id += 1
-                return i
-            else:
-                txn.execute(
-                    "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
-                )
-
-                val, = txn.fetchone()
-                cur = val or 0
-                cur += 1
-                self._next_id = cur + 1
-
-                return cur
+            return i
 
 
 class StreamIdGenerator(object):
@@ -69,7 +48,7 @@ class StreamIdGenerator(object):
     persistence of events can complete out of order.
 
     Usage:
-        with stream_id_gen.get_next_txn(txn) as stream_id:
+        with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
     def __init__(self, db_conn, table, column):
@@ -79,15 +58,21 @@ class StreamIdGenerator(object):
         self._lock = threading.Lock()
 
         cur = db_conn.cursor()
-        self._current_max = self._get_or_compute_current_max(cur)
+        self._current_max = self._load_current_max(cur)
         cur.close()
 
         self._unfinished_ids = deque()
 
-    def get_next(self, store):
+    def _load_current_max(self, txn):
+        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
+        rows = txn.fetchall()
+        val, = rows[0]
+        return int(val) if val else 1
+
+    def get_next(self):
         """
         Usage:
-            with yield stream_id_gen.get_next as stream_id:
+            with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -106,10 +91,10 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_next_mult(self, store, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with yield stream_id_gen.get_next(store, n) as stream_ids:
+            with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -139,13 +124,3 @@ class StreamIdGenerator(object):
                 return self._unfinished_ids[0] - 1
 
             return self._current_max
-
-    def _get_or_compute_current_max(self, txn):
-        with self._lock:
-            txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
-            rows = txn.fetchall()
-            val, = rows[0]
-
-            self._current_max = int(val) if val else 1
-
-            return self._current_max