diff --git a/changelog.d/8192.misc b/changelog.d/8192.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8192.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index fe30552c08..1d793d3deb 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import Dict, List
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
@@ -33,11 +33,11 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self.hs = hs
@cached(num_args=0)
- def get_monthly_active_count(self):
+ async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
Returns:
- Defered[int]: Number of current monthly active users
+ Number of current monthly active users
"""
def _count_users(txn):
@@ -46,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- return self.db_pool.runInteraction("count_users", _count_users)
+ return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
- def get_monthly_active_count_by_service(self):
+ async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
@@ -57,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
method to return anything other than native matrix users.
Returns:
- Deferred[dict]: dict that includes a mapping between app_service_id
- and the number of occurrences.
+ A mapping between app_service_id and the number of occurrences.
"""
@@ -74,7 +73,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
result = txn.fetchall()
return dict(result)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_users_by_service", _count_users_by_service
)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 7af2608ca4..9b9bc304a8 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,8 +15,9 @@
# limitations under the License.
import logging
+from collections import Counter
from itertools import chain
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.defer import DeferredLock
@@ -251,21 +252,23 @@ class StatsStore(StateDeltasStore):
desc="update_room_state",
)
- def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
+ async def get_statistics_for_subject(
+ self, stats_type: str, stats_id: str, start: str, size: int = 100
+ ) -> List[dict]:
"""
Get statistics for a given subject.
Args:
- stats_type (str): The type of subject
- stats_id (str): The ID of the subject (e.g. room_id or user_id)
- start (int): Pagination start. Number of entries, not timestamp.
- size (int): How many entries to return.
+ stats_type: The type of subject
+ stats_id: The ID of the subject (e.g. room_id or user_id)
+ start: Pagination start. Number of entries, not timestamp.
+ size: How many entries to return.
Returns:
- Deferred[list[dict]], where the dict has the keys of
+ A list of dicts, where the dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_statistics_for_subject",
self._get_statistics_for_subject_txn,
stats_type,
@@ -319,18 +322,17 @@ class StatsStore(StateDeltasStore):
allow_none=True,
)
- def bulk_update_stats_delta(self, ts, updates, stream_id):
+ async def bulk_update_stats_delta(
+ self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ ) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
Args:
- ts (int): Current timestamp in ms
- updates(dict[str, dict[str, dict[str, Counter]]]): The updates to
- commit as a mapping stats_type -> stats_id -> field -> delta.
- stream_id (int): Current position.
-
- Returns:
- Deferred
+ ts: Current timestamp in ms
+ updates: The updates to commit as a mapping of
+ stats_type -> stats_id -> field -> delta.
+ stream_id: Current position.
"""
def _bulk_update_stats_delta_txn(txn):
@@ -355,38 +357,37 @@ class StatsStore(StateDeltasStore):
updatevalues={"stream_id": stream_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn
)
- def update_stats_delta(
+ async def update_stats_delta(
self,
- ts,
- stats_type,
- stats_id,
- fields,
- complete_with_stream_id,
- absolute_field_overrides=None,
- ):
+ ts: int,
+ stats_type: str,
+ stats_id: str,
+ fields: Dict[str, int],
+ complete_with_stream_id: Optional[int],
+ absolute_field_overrides: Optional[Dict[str, int]] = None,
+ ) -> None:
"""
Updates the statistics for a subject, with a delta (difference/relative
change).
Args:
- ts (int): timestamp of the change
- stats_type (str): "room" or "user" – the kind of subject
- stats_id (str): the subject's ID (room ID or user ID)
- fields (dict[str, int]): Deltas of stats values.
- complete_with_stream_id (int, optional):
+ ts: timestamp of the change
+ stats_type: "room" or "user" – the kind of subject
+ stats_id: the subject's ID (room ID or user ID)
+ fields: Deltas of stats values.
+ complete_with_stream_id:
If supplied, converts an incomplete row into a complete row,
with the supplied stream_id marked as the stream_id where the
row was completed.
- absolute_field_overrides (dict[str, int]): Current stats values
- (i.e. not deltas) of absolute fields.
- Does not work with per-slice fields.
+ absolute_field_overrides: Current stats values (i.e. not deltas) of
+ absolute fields. Does not work with per-slice fields.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_stats_delta",
self._update_stats_delta_txn,
ts,
@@ -646,19 +647,20 @@ class StatsStore(StateDeltasStore):
txn, into_table, all_dest_keyvalues, src_row
)
- def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+ async def get_changes_room_total_events_and_bytes(
+ self, min_pos: int, max_pos: int
+ ) -> Dict[str, Dict[str, int]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
- min_pos (int)
- max_pos (int)
+ min_pos
+ max_pos
Returns:
- Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field
- changes.
+ Mapping of room ID to field changes.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn,
min_pos,
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c01b04e1dc..4b3fb018b1 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -24,6 +24,7 @@ from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ side_effect=lambda: make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
@@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ side_effect=lambda: make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ side_effect=lambda: make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
yield defer.ensureDeferred(
@@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ side_effect=lambda: make_awaitable(self.small_number_of_users)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 5c92d0e8c9..eddf5e2498 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -15,8 +15,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
@@ -102,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value - 1)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -110,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ side_effect=lambda: make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -118,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -128,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ side_effect=lambda: make_awaitable(self.lots_of_users)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 17d0aae2e9..160c630235 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,8 +20,6 @@ import urllib.parse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -29,6 +27,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -338,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -592,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -632,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 23db821fb7..973338ea71 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(1000)
+ side_effect=lambda user_id: make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
@@ -158,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(None)
+ side_effect=lambda user_id: make_awaitable(None)
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -261,10 +261,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
+ self.store.get_monthly_active_count = Mock(
+ side_effect=lambda: make_awaitable(1000)
+ )
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(1000)
+ side_effect=lambda user_id: make_awaitable(1000)
)
# Call the function multiple times to ensure we only send the notice once
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 224ea6fd79..370c247e16 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -16,13 +16,12 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(lots_of_users)
+ side_effect=lambda: make_awaitable(lots_of_users)
)
self.get_success(
self.store.insert_client_ip(
|