diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8166.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8179.misc b/changelog.d/8179.misc
new file mode 100644
index 0000000000..55bc079cdb
--- /dev/null
+++ b/changelog.d/8179.misc
@@ -0,0 +1 @@
+Add functions to `MultiWriterIdGenerator` used by the events stream.
diff --git a/changelog.d/8187.misc b/changelog.d/8187.misc
new file mode 100644
index 0000000000..cb557122aa
--- /dev/null
+++ b/changelog.d/8187.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.storage.database`.
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index d68b4bd670..769cd5de28 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module.
import logging
+from synapse.federation.units import Transaction
from synapse.logging.utils import log_function
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -49,15 +51,15 @@ class TransactionActions(object):
return self.store.get_received_txn_response(transaction.transaction_id, origin)
@log_function
- def set_response(self, origin, transaction, code, response):
+ async def set_response(
+ self, origin: str, transaction: Transaction, code: int, response: JsonDict
+ ) -> None:
""" Persist how we responded to a transaction.
-
- Returns:
- Deferred
"""
- if not transaction.transaction_id:
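+ # `Transaction` is a `JsonEncodedObject` whose attributes are set
+ # dynamically in `__init__`, so mypy cannot see `transaction_id`;
+ # hence the type ignore on the line below.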
+ transaction_id = transaction.transaction_id # type: ignore
+ if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")
- return self.store.set_received_txn_response(
- transaction.transaction_id, origin, code, response
+ await self.store.set_received_txn_response(
+ transaction_id, origin, code, response
)
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 6b32e0dcbf..64d98fc8f6 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject):
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]
- super(Transaction, self).__init__(
- transaction_id=transaction_id, pdus=pdus, **kwargs
- )
+ super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
@staticmethod
def create_new(pdus, **kwargs):
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 2f6f49a4bf..ba4c0c9af6 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,7 +28,6 @@ from typing import (
Optional,
Tuple,
TypeVar,
- Union,
overload,
)
@@ -1655,7 +1654,7 @@ class DatabasePool(object):
term: Optional[str],
col: str,
retcols: Iterable[str],
- ) -> Union[List[Dict[str, Any]], int]:
+ ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1667,14 +1666,14 @@ class DatabasePool(object):
retcols: the names of the columns to return
Returns:
- 0 if no term is given, otherwise a list of dictionaries.
+ None if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
- return 0
+ return None
return cls.cursor_to_dict(txn)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 77723f7d4d..92f56f1602 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- def set_appservice_state(self, service, state):
+ async def set_appservice_state(self, service, state) -> None:
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
- Returns:
- An Awaitable which resolves when the state was set successfully.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a811a39eb5..ecd3f3b310 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore):
return {row["user_id"] for row in rows}
- def mark_remote_user_device_cache_as_stale(self, user_id: str):
+ async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e3ead71853..8acf254bf3 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_room_from_summary",
)
- def upsert_group_category(self, group_id, category_id, profile, is_public):
+ async def upsert_group_category(
+ self,
+ group_id: str,
+ category_id: str,
+ profile: Optional[JsonDict],
+ is_public: Optional[bool],
+ ) -> None:
"""Add/update room category for group
"""
insertion_values = {}
@@ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore):
else:
update_values["is_public"] = is_public
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_group_category",
)
- def upsert_group_role(self, group_id, role_id, profile, is_public):
+ async def upsert_group_role(
+ self,
+ group_id: str,
+ role_id: str,
+ profile: Optional[JsonDict],
+ is_public: Optional[bool],
+ ) -> None:
"""Add/remove user role
"""
insertion_values = {}
@@ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore):
else:
update_values["is_public"] = is_public
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_user_from_summary",
)
- def add_group_invite(self, group_id, user_id):
+ async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
@@ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_user_from_group", _remove_user_from_group_txn
)
- def add_room_to_group(self, group_id, room_id, is_public):
- return self.db_pool.simple_insert(
+ async def add_room_to_group(
+ self, group_id: str, room_id: str, is_public: bool
+ ) -> None:
+ await self.db_pool.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index fadcad51e7..1c0a049c55 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -140,22 +140,28 @@ class KeyStore(SQLBaseStore):
for i in invalidations:
invalidate((i,))
- def store_server_keys_json(
- self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
- ):
+ async def store_server_keys_json(
+ self,
+ server_name: str,
+ key_id: str,
+ from_server: str,
+ ts_now_ms: int,
+ ts_expires_ms: int,
+ key_json_bytes: bytes,
+ ) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
- server_name (str): The name of the server.
- key_id (str): The identifer of the key this JSON is for.
- from_server (str): The server this JSON was fetched from.
- ts_now_ms (int): The time now in milliseconds.
- ts_valid_until_ms (int): The time when this json stops being valid.
- key_json (bytes): The encoded JSON.
+ server_name: The name of the server.
+ key_id: The identifier of the key this JSON is for.
+ from_server: The server this JSON was fetched from.
+ ts_now_ms: The time now in milliseconds.
+ ts_expires_ms: The time when this JSON stops being valid.
+ key_json_bytes: The encoded JSON.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8361dd63d9..3919ecad69 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media",
)
- def store_local_media(
+ async def store_local_media(
self,
media_id,
media_type,
@@ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
user_id,
url_cache=None,
- ):
- return self.db_pool.simple_insert(
+ ) -> None:
+ await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
- def store_url_cache(
+ async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media_thumbnails",
)
- def store_local_thumbnail(
+ async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
@@ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_cached_remote_media",
)
- def store_cached_remote_media(
+ async def store_cached_remote_media(
self,
origin,
media_id,
@@ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
- def store_remote_media_thumbnail(
+ async def store_remote_media_thumbnail(
self,
origin,
media_id,
@@ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index dcd1ff911a..4db8949da7 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -2,8 +2,10 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
- def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self.db_pool.simple_insert(
+ async def insert_open_id_token(
+ self, token: str, ts_valid_until_ms: int, user_id: str
+ ) -> None:
+ await self.db_pool.simple_insert(
table="open_id_tokens",
values={
"token": token,
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 5069f820bd..8b50e00553 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -125,8 +125,8 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
- def create_profile(self, user_localpart):
- return self.db_pool.simple_insert(
+ async def create_profile(self, user_localpart: str) -> None:
+ await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
@@ -197,8 +197,7 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore):
def __init__(self, database, db_conn, hs):
-
- super(ProfileStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"profile_replication_status_host_index",
@@ -208,13 +207,15 @@ class ProfileStore(ProfileWorkerStore):
unique=True,
)
- def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2772198fd9..c89d8863aa 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Any, Awaitable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -628,23 +628,22 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="user_delete_threepids",
)
- def add_user_bound_threepid(self, user_id, medium, address, id_server):
+ async def add_user_bound_threepid(
+ self, user_id: str, medium: str, address: str, id_server: str
+ ) -> None:
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str)
-
- Returns:
- Awaitable
+ user_id: The ID of the user.
+ medium: The medium of the threepid (e.g. "email").
+ address: The address of the threepid.
+ id_server: The identity server the bind request was proxied to.
"""
# We need to use an upsert, in case the user had already bound the
# threepid
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -1162,9 +1161,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- def record_user_external_id(
+ async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
- ) -> Awaitable:
+ ) -> None:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1172,7 +1171,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@@ -1316,12 +1315,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return res if res else False
- def add_user_pending_deactivation(self, user_id):
+ async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 315e3d4936..c1d8ef5286 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -1318,11 +1318,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return self.db_pool.runInteraction("get_rooms", f)
- def add_event_report(
- self, room_id, event_id, user_id, reason, content, received_ts
- ):
+ async def add_event_report(
+ self,
+ room_id: str,
+ event_id: str,
+ user_id: str,
+ reason: str,
+ content: JsonDict,
+ received_ts: int,
+ ) -> None:
next_id = self._event_reports_id_gen.get_next()
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="event_reports",
values={
"id": next_id,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9fe97af56a..7af2608ca4 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from itertools import chain
-from typing import Tuple
+from typing import Any, Dict, Tuple
from twisted.internet.defer import DeferredLock
@@ -222,11 +222,11 @@ class StatsStore(StateDeltasStore):
desc="stats_incremental_position",
)
- def update_room_state(self, room_id, fields):
+ async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
"""
Args:
- room_id (str)
- fields (dict[str:Any])
+ room_id: The room ID to update the state of.
+ fields: The fields to update, mapping `room_stats_state` column
+ names to their new values.
"""
# For whatever reason some of the fields may contain null bytes, which
@@ -244,7 +244,7 @@ class StatsStore(StateDeltasStore):
if field and "\0" in field:
fields[col] = None
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 52668dbdf9..2efcc0dc66 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
+from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
db_binary_type = memoryview
@@ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_received_txn_response(self, transaction_id, origin, code, response_dict):
- """Persist the response we returened for an incoming transaction, and
+ async def set_received_txn_response(
+ self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
+ ) -> None:
+ """Persist the response we returned for an incoming transaction, and
should return for subsequent transactions with the same transaction_id
and origin.
Args:
- txn
- transaction_id (str)
- origin (str)
- code (int)
- response_json (str)
+ transaction_id: The incoming transaction ID.
+ origin: The origin server.
+ code: The response code.
+ response_dict: The response, to be encoded into JSON.
"""
- return self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 5b07847773..b27a4843d0 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -343,6 +343,8 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
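+ # Also mark the new ID as persisted for this instance, so that
+ # `get_persisted_upto_position` advances once our own writes finish.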
+ self._add_persisted_position(next_id)
+
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9b9a183e7f..14ce21c786 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -58,6 +58,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
def _insert(txn):
for _ in range(number):
txn.execute(
@@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+ def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id, updating
+ the postgres sequence position to match.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
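+ # Bump the sequence so that subsequent `nextval` calls continue
+ # from the stream ID we just inserted explicitly.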
+ txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+
+ self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -188,11 +205,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
positions.
"""
- self._insert_rows("first", 3)
- self._insert_rows("second", 5)
+ # The following tests are a bit cheeky in that we notify about new
+ # positions via `advance` without *actually* advancing the postgres
+ # sequence.
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first")
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
# Min is 3 and there is a gap at 4 before 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
@@ -218,3 +241,26 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+ def test_get_persisted_upto_position_get_next(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions when `get_next` is called.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
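+ # "first" has only persisted up to 3; 4 is missing, so the 5 from
+ # "second" cannot count towards the persisted position yet.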
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
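+ # Leaving the context manager marks the ID as finished, so the
+ # persisted position should now advance to include it.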
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # We assume that so long as `get_next` does correctly advance the
+ # `persisted_upto_position` in this case, then it will be correct in the
+ # other cases that are tested above (since they'll hit the same code).
|