diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 338b495611..e2f9de8451 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
+import platform
+
from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine
-import importlib
-
-
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
"psycopg2": PostgresEngine,
@@ -31,6 +31,10 @@ def create_engine(database_config):
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
+ # pypy requires psycopg2cffi rather than psycopg2
+ if (name == "psycopg2" and
+ platform.python_implementation() == "PyPy"):
+ name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ txn.execute("SELECT nextval('state_group_id_seq')")
+ return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..19949fc474 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import prepare_database
-
import struct
+import threading
+
+from synapse.storage.prepare_database import prepare_database
class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
+ # The current max state_group, or None if we haven't looked
+ # in the DB yet.
+ self._current_state_group_id = None
+ self._current_state_group_id_lock = threading.Lock()
+
def check_database(self, txn):
pass
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ # We do application locking here since if we're using sqlite then
+ # we are a single process synapse.
+ with self._current_state_group_id_lock:
+ if self._current_state_group_id is None:
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ self._current_state_group_id = txn.fetchone()[0]
+
+ self._current_state_group_id += 1
+ return self._current_state_group_id
+
# Following functions taken from: https://github.com/coleifer/peewee
|