summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-07-16 13:54:45 +0100
committerGitHub <noreply@github.com>2020-07-16 13:54:45 +0100
commita827838706b1b0c36afb1f6d6a44ba69751b2ec6 (patch)
tree5478da2bd4d2c7ed3d42c0450d2f6dad07da36cf /synapse
parentAdd some tiny type annotations (#7870) (diff)
parentchangelog (diff)
downloadsynapse-a827838706b1b0c36afb1f6d6a44ba69751b2ec6.tar.xz
Merge pull request #7866 from matrix-org/rav/fix_guest_user_id
Fix guest user registration with lots of client readers
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/register.py22
-rw-r--r--synapse/storage/data_stores/main/registration.py65
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py34
-rw-r--r--synapse/storage/data_stores/state/store.py12
-rw-r--r--synapse/storage/engines/_base.py6
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite.py13
-rw-r--r--synapse/storage/util/id_generators.py8
-rw-r--r--synapse/storage/util/sequence.py98
9 files changed, 184 insertions, 80 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 78c3772ac1..501f0fe795 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -28,7 +28,6 @@ from synapse.replication.http.register import (
 )
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
-from synapse.util.async_helpers import Linearizer
 
 from ._base import BaseHandler
 
@@ -50,14 +49,7 @@ class RegistrationHandler(BaseHandler):
         self.user_directory_handler = hs.get_user_directory_handler()
         self.identity_handler = self.hs.get_handlers().identity_handler
         self.ratelimiter = hs.get_registration_ratelimiter()
-
-        self._next_generated_user_id = None
-
         self.macaroon_gen = hs.get_macaroon_generator()
-
-        self._generate_user_id_linearizer = Linearizer(
-            name="_generate_user_id_linearizer"
-        )
         self._server_notices_mxid = hs.config.server_notices_mxid
 
         if hs.config.worker_app:
@@ -219,7 +211,7 @@ class RegistrationHandler(BaseHandler):
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
 
-                localpart = await self._generate_user_id()
+                localpart = await self.store.generate_user_id()
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
                 self.check_user_id_not_appservice_exclusive(user_id)
@@ -510,18 +502,6 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    async def _generate_user_id(self):
-        if self._next_generated_user_id is None:
-            with await self._generate_user_id_linearizer.queue(()):
-                if self._next_generated_user_id is None:
-                    self._next_generated_user_id = (
-                        await self.store.find_next_generated_user_id_localpart()
-                    )
-
-        id = self._next_generated_user_id
-        self._next_generated_user_id += 1
-        return str(id)
-
     def check_registration_ratelimit(self, address):
         """A simple helper method to check whether the registration rate limit has been hit
         for a given IP address
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 587d4b91c1..27d2c5028c 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import Database
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore):
         self.config = hs.config
         self.clock = hs.get_clock()
 
+        self._user_id_seq = build_sequence_generator(
+            database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+        )
+
     @cached()
     def get_user_by_id(self, user_id):
         return self.db.simple_select_one(
@@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore):
         ret = yield self.db.runInteraction("count_real_users", _count_users)
         return ret
 
-    @defer.inlineCallbacks
-    def find_next_generated_user_id_localpart(self):
-        """
-        Gets the localpart of the next generated user ID.
+    async def generate_user_id(self) -> str:
+        """Generate a suitable localpart for a guest user
 
-        Generated user IDs are integers, so we find the largest integer user ID
-        already taken and return that plus one.
+        Returns: a (hopefully) free localpart
         """
-
-        def _find_next_generated_user_id(txn):
-            # We bound between '@0' and '@a' to avoid pulling the entire table
-            # out.
-            txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
-
-            regex = re.compile(r"^@(\d+):")
-
-            max_found = 0
-
-            for (user_id,) in txn:
-                match = regex.search(user_id)
-                if match:
-                    max_found = max(int(match.group(1)), max_found)
-
-            return max_found + 1
-
-        return (
-            (
-                yield self.db.runInteraction(
-                    "find_next_generated_user_id", _find_next_generated_user_id
-                )
-            )
+        next_id = await self.db.runInteraction(
+            "generate_user_id", self._user_id_seq.get_next_id_txn
         )
 
+        return str(next_id)
+
     async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
         """Returns user id from threepid
 
@@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             keyvalues={"user_id": user_id},
             values={"expiration_ts_ms": expiration_ts, "email_sent": False},
         )
+
+
+def find_max_generated_user_id_localpart(cur: Cursor) -> int:
+    """
+    Gets the localpart of the max current generated user ID.
+
+    Generated user IDs are integers, so we find the largest integer user ID
+    already taken and return that.
+    """
+
+    # We bound between '@0' and '@a' to avoid pulling the entire table
+    # out.
+    cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
+
+    regex = re.compile(r"^@(\d+):")
+
+    max_found = 0
+
+    for (user_id,) in cur:
+        match = regex.search(user_id)
+        if match:
+            max_found = max(int(match.group(1)), max_found)
+    return max_found
diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
new file mode 100644
index 0000000000..2011f6bceb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
@@ -0,0 +1,34 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adds a postgres SEQUENCE for generating guest user IDs.
+"""
+
+from synapse.storage.data_stores.main.registration import (
+    find_max_generated_user_id_localpart,
+)
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if not isinstance(database_engine, PostgresEngine):
+        return
+
+    next_id = find_max_generated_user_id_localpart(cur) + 1
+    cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 5db9f20135..128c09a2cf 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "*stateGroupMembersCache*", 500000,
         )
 
+        def get_max_state_group_txn(txn: Cursor):
+            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+            return txn.fetchone()[0]
+
+        self._state_group_seq_gen = build_sequence_generator(
+            self.database_engine, get_max_state_group_txn, "state_group_id_seq"
+        )
+
     @cached(max_entries=10000, iterable=True)
     def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
@@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 # AFAIK, this can never happen
                 raise Exception("current_state_ids cannot be None")
 
-            state_group = self.database_engine.get_next_state_group_id(txn)
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
 
             self.db.simple_insert_txn(
                 txn,
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4bd3..908cbc79e3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
     def lock_table(self, txn, table: str) -> None:
         ...
 
-    @abc.abstractmethod
-    def get_next_state_group_id(self, txn) -> int:
-        """Returns an int that can be used as a new state_group ID
-        """
-        ...
-
     @property
     @abc.abstractmethod
     def server_version(self) -> str:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a31588080d..ff39281f85 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
 
-    def get_next_state_group_id(self, txn):
-        """Returns an int that can be used as a new state_group ID
-        """
-        txn.execute("SELECT nextval('state_group_id_seq')")
-        return txn.fetchone()[0]
-
     @property
     def server_version(self):
         """Returns a string giving the server version. For example: '8.1.5'
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a949442..8a0f8c89d1 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
     def lock_table(self, txn, table):
         return
 
-    def get_next_state_group_id(self, txn):
-        """Returns an int that can be used as a new state_group ID
-        """
-        # We do application locking here since if we're using sqlite then
-        # we are a single process synapse.
-        with self._current_state_group_id_lock:
-            if self._current_state_group_id is None:
-                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
-                self._current_state_group_id = txn.fetchone()[0]
-
-            self._current_state_group_id += 1
-            return self._current_state_group_id
-
     @property
     def server_version(self):
         """Gets a string giving the server version. For example: '3.22.0'
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..787cebfbec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
 from typing_extensions import Deque
 
 from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 
 class IdGenerator(object):
@@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
     ):
         self._db = db
         self._instance_name = instance_name
-        self._sequence_name = sequence_name
 
         # We lock as some functions may be called from DB threads.
         self._lock = threading.Lock()
@@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ) -> Dict[str, int]:
@@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
         return current_positions
 
     def _load_next_id_txn(self, txn):
-        txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        (next_id,) = txn.fetchone()
-        return next_id
+        return self._sequence_gen.get_next_id_txn(txn)
 
     async def get_next(self):
         """
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..63dfea4220
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+    """A class which generates a unique sequence of integers"""
+
+    @abc.abstractmethod
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        """Gets the next ID in the sequence"""
+        ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+    def __init__(self, sequence_name: str):
+        self._sequence_name = sequence_name
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses local locking
+
+    This only works reliably if there are no other worker processes generating IDs at
+    the same time.
+    """
+
+    def __init__(self, get_first_callback: GetFirstCallbackType):
+        """
+        Args:
+            get_first_callback: a callback which is called on the first call to
+                 get_next_id_txn; should return the curreent maximum id
+        """
+        # the callback. this is cleared after it is called, so that it can be GCed.
+        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+
+        # The current max value, or None if we haven't looked in the DB yet.
+        self._current_max_id = None  # type: Optional[int]
+        self._lock = threading.Lock()
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            self._current_max_id += 1
+            return self._current_max_id
+
+
+def build_sequence_generator(
+    database_engine: BaseDatabaseEngine,
+    get_first_callback: GetFirstCallbackType,
+    sequence_name: str,
+) -> SequenceGenerator:
+    """Get the best impl of SequenceGenerator available
+
+    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+    sqlite.
+
+    Args:
+        database_engine: the database engine we are connected to
+        get_first_callback: a callback which gets the next sequence ID. Used if
+            we're on sqlite.
+        sequence_name: the name of a postgres sequence to use.
+    """
+    if isinstance(database_engine, PostgresEngine):
+        return PostgresSequenceGenerator(sequence_name)
+    else:
+        return LocalSequenceGenerator(get_first_callback)