diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 2193d8fdc5..cf039e7f7d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,13 +18,12 @@ import abc
import logging
from typing import List, Tuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -327,7 +326,7 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
A deferred that completes once the account_data has been added.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
@@ -373,7 +372,7 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
A deferred that completes once the account_data has been added.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 683afde52b..10de446065 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
- self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 712c8d0264..216a5925fc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,8 +14,7 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
@@ -82,21 +81,19 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
"devices_last_seen", self._devices_last_seen_update
)
- @defer.inlineCallbacks
- def _remove_user_ip_nonunique(self, progress, batch_size):
+ async def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
- yield self.db_pool.runWithConnection(f)
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.runWithConnection(f)
+ await self.db_pool.updates._end_background_update(
"user_ips_drop_nonunique_index"
)
return 1
- @defer.inlineCallbacks
- def _analyze_user_ip(self, progress, batch_size):
+ async def _analyze_user_ip(self, progress, batch_size):
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
@@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
+ await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
- yield self.db_pool.updates._end_background_update("user_ips_analyze")
+ await self.db_pool.updates._end_background_update("user_ips_analyze")
return 1
- @defer.inlineCallbacks
- def _remove_user_ip_dupes(self, progress, batch_size):
+ async def _remove_user_ip_dupes(self, progress, batch_size):
# This works function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
@@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
- end_last_seen = yield self.db_pool.runInteraction(
+ end_last_seen = await self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@@ -275,15 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.db_pool.runInteraction("user_ips_dups_remove", remove)
+ await self.db_pool.runInteraction("user_ips_dups_remove", remove)
if last:
- yield self.db_pool.updates._end_background_update("user_ips_remove_dupes")
+ await self.db_pool.updates._end_background_update("user_ips_remove_dupes")
return batch_size
- @defer.inlineCallbacks
- def _devices_last_seen_update(self, progress, batch_size):
+ async def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
@@ -346,12 +341,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return len(rows)
- updated = yield self.db_pool.runInteraction(
+ updated = await self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
- yield self.db_pool.updates._end_background_update("devices_last_seen")
+ await self.db_pool.updates._end_background_update("devices_last_seen")
return updated
@@ -380,8 +375,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
- @defer.inlineCallbacks
- def insert_client_ip(
+ async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now:
@@ -392,7 +386,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- yield self.populate_monthly_active_users(user_id)
+ await self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@@ -461,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
- @defer.inlineCallbacks
- def get_last_client_ip_by_device(self, user_id, device_id):
+ async def get_last_client_ip_by_device(
+ self, user_id: str, device_id: Optional[str]
+ ) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on
Args:
- user_id (str)
- device_id (str): If None fetches all devices for the user
+ user_id: The user to fetch devices for.
+ device_id: If None fetches all devices for the user
Returns:
- defer.Deferred: resolves to a dict, where the keys
- are (user_id, device_id) tuples. The values are also dicts, with
- keys giving the column names
+ A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+ keys giving the column names from the devices table.
"""
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
- res = yield self.db_pool.simple_select_list(
+ res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -501,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
return ret
- @defer.inlineCallbacks
- def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}
@@ -512,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self.db_pool.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 874ecdf8d2..76ec954f44 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -16,13 +16,12 @@
import logging
from typing import List, Tuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
+from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@@ -354,7 +353,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = json.dumps(edu)
+ edu_json = json_encoder.encode(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
@@ -432,7 +431,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Handle wildcard device_ids.
sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
- message_json = json.dumps(messages_by_device["*"])
+ message_json = json_encoder.encode(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -454,7 +453,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = json.dumps(messages_by_device[device])
+ message_json = json_encoder.encode(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 88a7aadfc6..7a5f0bab05 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -17,8 +17,6 @@
import logging
from typing import List, Optional, Set, Tuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError
@@ -36,6 +34,7 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import (
Cache,
cached,
@@ -137,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
- cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ cross_signing_key = yield defer.ensureDeferred(
+ self.get_e2e_cross_signing_key(user, "master")
+ )
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
@@ -150,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- cross_signing_key = yield self.get_e2e_cross_signing_key(
- user, "self_signing"
+ cross_signing_key = yield defer.ensureDeferred(
+ self.get_e2e_cross_signing_key(user, "self_signing")
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
@@ -247,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore):
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
- user_id/device_id to update stream_id and the relevent json-encoded
+ user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
Returns:
@@ -397,7 +398,7 @@ class DeviceWorkerStore(SQLBaseStore):
values={
"stream_id": stream_id,
"from_user_id": from_user_id,
- "user_ids": json.dumps(user_ids),
+ "user_ids": json_encoder.encode(user_ids),
},
)
@@ -600,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore):
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
- function to get further updatees.
+ function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
@@ -1032,7 +1033,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={"content": json.dumps(content)},
+ values={"content": json_encoder.encode(content)},
# we don't need to lock, because we assume we are the only thread
# updating this user's devices.
lock=False,
@@ -1088,7 +1089,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
{
"user_id": user_id,
"device_id": content["device_id"],
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
}
for content in devices
],
@@ -1209,7 +1210,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
- "opentracing_context": json.dumps(context)
+ "opentracing_context": json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
}
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 7819bfcbb3..037e02603c 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,30 +14,29 @@
# limitations under the License.
from collections import namedtuple
-from typing import Optional
-
-from twisted.internet import defer
+from typing import Iterable, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
+from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias):
- """ Get's the room_id and server list for a given room_alias
+ async def get_association_from_room_alias(
+ self, room_alias: RoomAlias
+ ) -> Optional[RoomAliasMapping]:
+ """Gets the room_id and server list for a given room_alias
Args:
- room_alias (RoomAlias)
+ room_alias: The alias to translate to an ID.
Returns:
- Deferred: results in namedtuple with keys "room_id" and
- "servers" or None if no association can be found
+ The room alias mapping or None if no association can be found.
"""
- room_id = yield self.db_pool.simple_select_one_onecol(
+ room_id = await self.db_pool.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
- servers = yield self.db_pool.simple_select_onecol(
+ servers = await self.db_pool.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
- @defer.inlineCallbacks
- def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
+ async def create_room_alias_association(
+ self,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Iterable[str],
+ creator: Optional[str] = None,
+ ) -> None:
""" Creates an association between a room alias and room_id/servers
Args:
- room_alias (RoomAlias)
- room_id (str)
- servers (list)
- creator (str): Optional user_id of creator.
-
- Returns:
- Deferred
+ room_alias: The alias to create.
+ room_id: The target of the alias.
+ servers: A list of servers through which it may be possible to join the room
+ creator: Optional user_id of creator.
"""
def alias_txn(txn):
@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- return ret
- @defer.inlineCallbacks
- def delete_room_alias(self, room_alias):
- room_id = yield self.db_pool.runInteraction(
+ async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+ room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
- def _delete_room_alias_txn(self, txn, room_alias):
+ def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 90152edc3c..2eeb9f97dc 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,18 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util import json_encoder
class EndToEndRoomKeyStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ async def update_e2e_room_key(
+ self, user_id, version, room_id, session_id, room_key
+ ):
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
@@ -38,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -50,13 +48,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
- "session_data": json.dumps(room_key["session_data"]),
+ "session_data": json_encoder.encode(room_key["session_data"]),
},
desc="update_e2e_room_key",
)
- @defer.inlineCallbacks
- def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.
Args:
@@ -77,7 +74,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
- "session_data": json.dumps(room_key["session_data"]),
+ "session_data": json_encoder.encode(room_key["session_data"]),
}
)
log_kv(
@@ -89,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}
)
- yield self.db_pool.simple_insert_many(
+ await self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
- @defer.inlineCallbacks
- def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
@@ -110,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -125,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = yield self.db_pool.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -243,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- @defer.inlineCallbacks
- def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_e2e_room_keys(
+ self, user_id, version, room_id=None, session_id=None
+ ):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
@@ -259,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred of the deletion transaction
+ The deletion transaction
"""
keyvalues = {"user_id": user_id, "version": int(version)}
@@ -268,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- yield self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -360,7 +357,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"user_id": user_id,
"version": new_version,
"algorithm": info["algorithm"],
- "auth_data": json.dumps(info["auth_data"]),
+ "auth_data": json_encoder.encode(info["auth_data"]),
},
)
@@ -387,7 +384,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues = {}
if info is not None and "auth_data" in info:
- updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
if version_etag is not None:
updatevalues["etag"] = version_etag
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 40354b8304..f93e0d320d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,24 +14,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
-from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
- @defer.inlineCallbacks
- def get_e2e_device_keys(
+ async def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
@@ -51,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = yield self.db_pool.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
@@ -174,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv(result)
return result
- @defer.inlineCallbacks
- def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ async def get_e2e_one_time_keys(
+ self, user_id: str, device_id: str, key_ids: List[str]
+ ) -> Dict[Tuple[str, str], str]:
"""Retrieve a number of one-time keys for a user
Args:
@@ -185,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
retrieve
Returns:
- deferred resolving to Dict[(str, str), str]: map from (algorithm,
- key_id) to json string for key
+ A map from (algorithm, key_id) to json string for key
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@@ -201,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
- @defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ async def add_e2e_one_time_keys(
+ self,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
"""Insert some new one time keys for a device. Errors if any of the
keys already exist.
Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- time_now(long): insertion time to record (ms since epoch)
- new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
- (algorithm, key_id, key json)
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn):
@@ -241,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@@ -268,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- @defer.inlineCallbacks
- def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ async def get_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, from_user_id: Optional[str] = None
+ ) -> Optional[dict]:
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
+ user_id: the user whose key is being requested
+ key_type: the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
+ from_user_id: if specified, signatures made by this user on
the self-signing key will be included in the result
Returns:
dict of the key data or None if not found
"""
- res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
user_keys = res.get(user_id)
if not user_keys:
return None
@@ -449,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return keys
- @defer.inlineCallbacks
- def get_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str], from_user_id: str = None
- ) -> defer.Deferred:
+ async def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: Optional[str] = None
+ ) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users.
Args:
- user_ids (list[str]): the users whose keys are being requested
- from_user_id (str): if specified, signatures made by this user on
+ user_ids: the users whose keys are being requested
+ from_user_id: if specified, signatures made by this user on
the self-signing keys will be included in the result
Returns:
- Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
- key data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A map of user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict,
+ or their user ID will map to None.
"""
- result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+ result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
@@ -700,7 +702,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
values={
"user_id": user_id,
"keytype": key_type,
- "keydata": json.dumps(key),
+ "keydata": json_encoder.encode(key),
"stream_id": stream_id,
},
)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b8cefb4d5e..7c246d3e4c 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -17,11 +17,10 @@
import logging
from typing import List
-from canonicaljson import json
-
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -50,7 +49,7 @@ def _serialize_action(actions, is_highlight):
else:
if actions == DEFAULT_NOTIF_ACTION:
return ""
- return json.dumps(actions)
+ return json_encoder.encode(actions)
def _deserialize_action(actions, is_highlight):
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4d8a24ce4b..1a68bf32cb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -53,47 +53,6 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
-STATE_EVENT_TYPES_TO_MARK_UNREAD = {
- EventTypes.Topic,
- EventTypes.Name,
- EventTypes.RoomAvatar,
- EventTypes.Tombstone,
-}
-
-
-def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
- # Exclude rejected and soft-failed events.
- if context.rejected or event.internal_metadata.is_soft_failed():
- return False
-
- # Exclude notices.
- if (
- not event.is_state()
- and event.type == EventTypes.Message
- and event.content.get("msgtype") == "m.notice"
- ):
- return False
-
- # Exclude edits.
- relates_to = event.content.get("m.relates_to", {})
- if relates_to.get("rel_type") == RelationTypes.REPLACE:
- return False
-
- # Mark events that have a non-empty string body as unread.
- body = event.content.get("body")
- if isinstance(body, str) and body:
- return True
-
- # Mark some state events as unread.
- if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
- return True
-
- # Mark encrypted events as unread.
- if not event.is_state() and event.type == EventTypes.Encrypted:
- return True
-
- return False
-
def encode_json(json_object):
"""
@@ -239,10 +198,6 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- self.store.get_unread_message_count_for_user.invalidate_many(
- (event.room_id,),
- )
-
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -864,9 +819,8 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
- "count_as_unread": should_count_as_unread(event, context),
}
- for event, context in events_and_contexts
+ for event, _ in events_and_contexts
],
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a7b7393f6e..755b7a2a85 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import (
- Cache,
- _CacheContext,
- cached,
- cachedInlineCallbacks,
-)
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
- @cached(tree=True, cache_context=True)
- async def get_unread_message_count_for_user(
- self, room_id: str, user_id: str, cache_context: _CacheContext,
- ) -> int:
- """Retrieve the count of unread messages for the given room and user.
-
- Args:
- room_id: The ID of the room to count unread messages in.
- user_id: The ID of the user to count unread messages for.
-
- Returns:
- The number of unread messages for the given user in the given room.
- """
- with Measure(self._clock, "get_unread_message_count_for_user"):
- last_read_event_id = await self.get_last_receipt_event_id_for_user(
- user_id=user_id,
- room_id=room_id,
- receipt_type="m.read",
- on_invalidate=cache_context.invalidate,
- )
-
- return await self.db_pool.runInteraction(
- "get_unread_message_count_for_user",
- self._get_unread_message_count_for_user_txn,
- user_id,
- room_id,
- last_read_event_id,
- )
-
- def _get_unread_message_count_for_user_txn(
- self,
- txn: Cursor,
- user_id: str,
- room_id: str,
- last_read_event_id: Optional[str],
- ) -> int:
- if last_read_event_id:
- # Get the stream ordering for the last read event.
- stream_ordering = self.db_pool.simple_select_one_onecol_txn(
- txn=txn,
- table="events",
- keyvalues={"room_id": room_id, "event_id": last_read_event_id},
- retcol="stream_ordering",
- )
- else:
- # If there's no read receipt for that room, it probably means the user hasn't
- # opened it yet, in which case use the stream ID of their join event.
- # We can't just set it to 0 otherwise messages from other local users from
- # before this user joined will be counted as well.
- txn.execute(
- """
- SELECT stream_ordering FROM local_current_membership
- LEFT JOIN events USING (event_id, room_id)
- WHERE membership = 'join'
- AND user_id = ?
- AND room_id = ?
- """,
- (user_id, room_id),
- )
- row = txn.fetchone()
-
- if row is None:
- return 0
-
- stream_ordering = row[0]
-
- # Count the messages that qualify as unread after the stream ordering we've just
- # retrieved.
- sql = """
- SELECT COUNT(*) FROM events
- WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
- """
-
- txn.execute(sql, (user_id, room_id, stream_ordering))
- row = txn.fetchone()
-
- return row[0] if row else 0
-
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index a98181f445..75ea6d4b2f 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -16,12 +16,11 @@
from typing import List, Tuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util import json_encoder
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -752,7 +751,7 @@ class GroupServerStore(GroupServerWorkerStore):
if profile is None:
insertion_values["profile"] = "{}"
else:
- update_values["profile"] = json.dumps(profile)
+ update_values["profile"] = json_encoder.encode(profile)
if is_public is None:
insertion_values["is_public"] = True
@@ -783,7 +782,7 @@ class GroupServerStore(GroupServerWorkerStore):
if profile is None:
insertion_values["profile"] = "{}"
else:
- update_values["profile"] = json.dumps(profile)
+ update_values["profile"] = json_encoder.encode(profile)
if is_public is None:
insertion_values["is_public"] = True
@@ -1007,7 +1006,7 @@ class GroupServerStore(GroupServerWorkerStore):
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json.dumps(remote_attestation),
+ "attestation_json": json_encoder.encode(remote_attestation),
},
)
@@ -1131,7 +1130,7 @@ class GroupServerStore(GroupServerWorkerStore):
"is_admin": is_admin,
"membership": membership,
"is_publicised": is_publicised,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
)
@@ -1143,7 +1142,7 @@ class GroupServerStore(GroupServerWorkerStore):
"group_id": group_id,
"user_id": user_id,
"type": "membership",
- "content": json.dumps(
+ "content": json_encoder.encode(
{"membership": membership, "content": content}
),
},
@@ -1171,7 +1170,7 @@ class GroupServerStore(GroupServerWorkerStore):
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json.dumps(remote_attestation),
+ "attestation_json": json_encoder.encode(remote_attestation),
},
)
else:
@@ -1240,7 +1239,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
"valid_until_ms": attestation["valid_until_ms"],
- "attestation_json": json.dumps(attestation),
+ "attestation_json": json_encoder.encode(attestation),
},
desc="update_remote_attestion",
)
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 02b01d9619..e71cdd2cb4 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,8 +15,6 @@
import logging
from typing import List
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
@@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
- @defer.inlineCallbacks
- def upsert_monthly_active_user(self, user_id):
+ async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
Args:
- user_id (str): user to add/update
-
- Returns:
- Deferred
+ user_id: user to add/update
"""
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# _initialise_reserved_users reasoning that it would be very strange to
# include a support user in this context.
- is_support = yield self.is_support_user(user_id)
+ is_support = await self.is_support_user(user_id)
if is_support:
return
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
@@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
return is_insert
- @defer.inlineCallbacks
- def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = yield self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id)
if is_guest:
return
- is_trial = yield self.is_trial_user(user_id)
+ is_trial = await self.is_trial_user(user_id)
if is_trial:
return
- last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy
@@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
if not self._limit_usage_by_mau:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
else:
- count = yield self.get_monthly_active_count()
+ count = await self.get_monthly_active_count()
if count < self._max_mau_value:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 5fd899326a..19a0211a03 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
@@ -33,6 +31,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -419,8 +418,8 @@ class PushRuleStore(PushRulesWorkerStore):
before=None,
after=None,
):
- conditions_json = json.dumps(conditions)
- actions_json = json.dumps(actions)
+ conditions_json = json_encoder.encode(conditions)
+ actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
@@ -689,7 +688,7 @@ class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
- actions_json = json.dumps(actions)
+ actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 6255977c92..1920a8a152 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -18,13 +18,12 @@ import abc
import logging
from typing import List, Tuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -459,7 +458,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
values={
"stream_id": stream_id,
"event_id": event_id,
- "data": json.dumps(data),
+ "data": json_encoder.encode(data),
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
@@ -585,7 +584,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
- "event_ids": json.dumps(event_ids),
- "data": json.dumps(data),
+ "event_ids": json_encoder.encode(event_ids),
+ "data": json_encoder.encode(data),
},
)
diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql
deleted file mode 100644
index 531b532c73..0000000000
--- a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql
+++ /dev/null
@@ -1,18 +0,0 @@
-/* Copyright 2020 The Matrix.org Foundation C.I.C
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
--- Store a boolean value in the events table for whether the event should be counted in
--- the unread_count property of sync responses.
-ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2162d0712d..7f8d1880e5 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,8 +16,7 @@
import logging
import re
from collections import namedtuple
-
-from twisted.internet import defer
+from typing import List, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -114,8 +113,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
- @defer.inlineCallbacks
- def _background_reindex_search(self, progress, batch_size):
+ async def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -206,19 +204,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return len(event_search_rows)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_reindex_gin_search(self, progress, batch_size):
+ async def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
@@ -255,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
- yield self.db_pool.runWithConnection(create_index)
+ await self.db_pool.runWithConnection(create_index)
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
)
return 1
- @defer.inlineCallbacks
- def _background_reindex_search_order(self, progress, batch_size):
+ async def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -288,12 +284,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
conn.set_session(autocommit=False)
- yield self.db_pool.runWithConnection(create_index)
+ await self.db_pool.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self.db_pool.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
@@ -331,12 +327,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return len(rows), True
- num_rows, finished = yield self.db_pool.runInteraction(
+ num_rows, finished = await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
if not finished:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME
)
@@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)
- @defer.inlineCallbacks
- def search_msgs(self, room_ids, search_term, keys):
+ async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
@@ -425,7 +420,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self.db_pool.execute(
+ results = await self.db_pool.execute(
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
)
@@ -433,7 +428,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -442,11 +437,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
- count_results = yield self.db_pool.execute(
+ count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)
@@ -462,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- @defer.inlineCallbacks
- def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
+ async def search_rooms(
+ self,
+ room_ids: List[str],
+ search_term: str,
+ keys: List[str],
+ limit,
+ pagination_token: Optional[str] = None,
+ ) -> List[dict]:
"""Performs a full text search over events with given keys.
Args:
- room_id (list): The room_ids to search in
- search_term (str): Search term to search for
- keys (list): List of keys to search in, currently supports
- "content.body", "content.name", "content.topic"
- pagination_token (str): A pagination token previously returned
+ room_ids: The room_ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports "content.body",
+ "content.name", "content.topic"
+ pagination_token: A pagination token previously returned
Returns:
- list of dicts
+ Each match as a dictionary.
"""
clauses = []
@@ -577,7 +578,7 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit)
- results = yield self.db_pool.execute(
+ results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
)
@@ -585,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -594,11 +595,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
- count_results = yield self.db_pool.execute(
+ count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index dae8e8bd29..be191dd870 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -15,8 +15,6 @@
from unpaddedbase64 import encode_base64
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -40,9 +38,8 @@ class SignatureWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction("get_event_reference_hashes", f)
- @defer.inlineCallbacks
- def add_event_hashes(self, event_ids):
- hashes = yield self.get_event_reference_hashes(event_ids)
+ async def add_event_hashes(self, event_ids):
+ hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d73a8e8ab9..af21fe457a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,8 +16,6 @@
import logging
import re
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state import StateFilter
@@ -59,8 +57,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
- @defer.inlineCallbacks
- def _populate_user_directory_createtables(self, progress, batch_size):
+ async def _populate_user_directory_createtables(self, progress, batch_size):
# Get all the rooms that we want to process.
def _make_staging_area(txn):
@@ -102,45 +99,43 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
- new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.db_pool.runInteraction(
+ new_pos = await self.get_max_stream_id_in_current_state_deltas()
+ await self.db_pool.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
TEMP_TABLE + "_position", {"position": new_pos}
)
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_createtables"
)
return 1
- @defer.inlineCallbacks
- def _populate_user_directory_cleanup(self, progress, batch_size):
+ async def _populate_user_directory_cleanup(self, progress, batch_size):
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self.db_pool.simple_select_one_onecol(
+ position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
- yield self.update_user_directory_stream_pos(position)
+ await self.update_user_directory_stream_pos(position)
def _delete_staging_area(txn):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_cleanup"
)
return 1
- @defer.inlineCallbacks
- def _populate_user_directory_process_rooms(self, progress, batch_size):
+ async def _populate_user_directory_process_rooms(self, progress, batch_size):
"""
Args:
progress (dict)
@@ -151,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If we don't have progress filed, delete everything.
if not progress:
- yield self.delete_all_from_user_dir()
+ await self.delete_all_from_user_dir()
def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even
@@ -176,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return rooms_to_work_on
- rooms_to_work_on = yield self.db_pool.runInteraction(
+ rooms_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_rooms"
)
return 1
@@ -195,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
- is_in_room = yield self.is_host_joined(room_id, self.server_name)
+ is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room:
- is_public = yield self.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.is_room_world_readable_or_publicly_joinable(
room_id
)
- users_with_profile = yield defer.ensureDeferred(
- state.get_current_users_in_room(room_id)
- )
+ users_with_profile = await state.get_current_users_in_room(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.
for user_id, profile in users_with_profile.items():
- yield self.update_profile_in_user_dir(
+ await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -223,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
to_insert.add(user_id)
if to_insert:
- yield self.add_users_in_public_rooms(room_id, to_insert)
+ await self.add_users_in_public_rooms(room_id, to_insert)
to_insert.clear()
else:
for user_id in user_ids:
@@ -243,22 +236,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If it gets too big, stop and write to the database
# to prevent storing too much in RAM.
if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
- yield self.add_users_who_share_private_room(
+ await self.add_users_who_share_private_room(
room_id, to_insert
)
to_insert.clear()
if to_insert:
- yield self.add_users_who_share_private_room(room_id, to_insert)
+ await self.add_users_who_share_private_room(room_id, to_insert)
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self.db_pool.simple_delete_one(
+ await self.db_pool.simple_delete_one(
TEMP_TABLE + "_rooms", {"room_id": room_id}
)
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory",
self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_rooms",
@@ -273,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
- @defer.inlineCallbacks
- def _populate_user_directory_process_users(self, progress, batch_size):
+ async def _populate_user_directory_process_users(self, progress, batch_size):
"""
If search_all_users is enabled, add all of the users to the user directory.
"""
if not self.hs.config.user_directory_search_all_users:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
@@ -305,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return users_to_work_on
- users_to_work_on = yield self.db_pool.runInteraction(
+ users_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more users -- complete the transaction.
if not users_to_work_on:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
@@ -322,18 +314,18 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for user_id in users_to_work_on:
- profile = yield self.get_profileinfo(get_localpart_from_id(user_id))
- yield self.update_profile_in_user_dir(
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
# We've finished processing a user. Delete it from the table.
- yield self.db_pool.simple_delete_one(
+ await self.db_pool.simple_delete_one(
TEMP_TABLE + "_users", {"user_id": user_id}
)
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory",
self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_users",
@@ -342,8 +334,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
- @defer.inlineCallbacks
- def is_room_world_readable_or_publicly_joinable(self, room_id):
+ async def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
"""
@@ -353,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = yield self.get_filtered_current_state_ids(
+ current_state_ids = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
- join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
- hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
return True
@@ -590,19 +581,18 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn
)
- @defer.inlineCallbacks
- def get_users_in_dir_due_to_room(self, room_id):
+ async def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_share_pub = yield self.db_pool.simple_select_onecol(
+ user_ids_share_pub = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share_priv = yield self.db_pool.simple_select_onecol(
+ user_ids_share_priv = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
@@ -645,8 +635,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- @defer.inlineCallbacks
- def get_user_dir_rooms_user_is_in(self, user_id):
+ async def get_user_dir_rooms_user_is_in(self, user_id):
"""
Returns the rooms that a user is in.
@@ -656,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns:
list: user_id
"""
- rows = yield self.db_pool.simple_select_onecol(
+ rows = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self.db_pool.simple_select_onecol(
+ pub_rows = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -674,32 +663,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- @defer.inlineCallbacks
- def get_rooms_in_common_for_users(self, user_id, other_user_id):
- """Given two user_ids find out the list of rooms they share.
- """
- sql = """
- SELECT room_id FROM (
- SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (event_id)
- WHERE type = 'm.room.member'
- AND m.membership = 'join'
- AND state_key = ?
- ) AS f1 INNER JOIN (
- SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (event_id)
- WHERE type = 'm.room.member'
- AND m.membership = 'join'
- AND state_key = ?
- ) f2 USING (room_id)
- """
-
- rows = yield self.db_pool.execute(
- "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
- )
-
- return [room_id for room_id, in rows]
-
def get_user_directory_stream_pos(self):
return self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
@@ -708,8 +671,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
- @defer.inlineCallbacks
- def search_user_dir(self, user_id, search_term, limit):
+ async def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
Returns:
@@ -806,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self.db_pool.execute(
+ results = await self.db_pool.execute(
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
)
|