diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..b0353ac2dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -86,7 +86,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +101,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +113,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +121,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +140,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +149,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
-
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
- return manager()
+ return _MultiWriterCtxManager(self)
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+ return False
|